diff --git a/README.md b/README.md index 777c82cfef5aa1b03038cbad9f0c5bb30311576c..544b5544830a131e9c00d33fbe75ebee6fbf694d 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity) [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues) [![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep) -[![Python 3.6 3.7 3.8 3.9](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/) +[![Python 3.7 3.8 3.9](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/) # pytorch-widedeep @@ -24,6 +24,13 @@ using wide and deep models. **Slack**: if you want to contribute or just want to chat with us, join [slack](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw) +The content of this document is organized as follows: + +1. [introduction](#introduction) +2. [The deeptabular component](#the-deeptabular-component) +3. [installation](#installation) +4. [quick start (tl;dr)](#quick-start) + ### Introduction ``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792) @@ -82,61 +89,58 @@ into:

+I recommend using the ``wide`` and ``deeptabular`` models in +``pytorch-widedeep``. However it is very likely that users will want to use +their own models for the ``deeptext`` and ``deepimage`` components. That is +perfectly possible as long as the the custom models have an attribute called +``output_dim`` with the size of the last layer of activations, so that +``WideDeep`` can be constructed. Again, examples on how to use custom +components can be found in the Examples folder. Just in case +``pytorch-widedeep`` includes standard text (stack of LSTMs) and image +(pre-trained ResNets or stack of CNNs) models. + +### The ``deeptabular`` component + It is important to emphasize that **each individual component, `wide`, `deeptabular`, `deeptext` and `deepimage`, can be used independently** and in isolation. For example, one could use only `wide`, which is in simply a linear model. In fact, one of the most interesting functionalities -in``pytorch-widedeep`` is the ``deeptabular`` component. Currently, -``pytorch-widedeep`` offers the following different models for that -component: +in``pytorch-widedeep`` would be the use of the ``deeptabular`` component on +its own, i.e. what one might normally refer as Deep Learning for Tabular +Data. Currently, ``pytorch-widedeep`` offers the following different models +for that component: -1. ``TabMlp``: this is almost identical to the [tabular -model](https://docs.fast.ai/tutorial.tabular.html) in the fantastic -[fastai](https://docs.fast.ai/) library, and consists simply in embeddings -representing the categorical features, concatenated with the continuous -features, and passed then through a MLP. -2. ``TabRenset``: This is similar to the previous model but the embeddings are +1. **TabMlp**: a simple MLP that receives embeddings representing the +categorical features, concatenated with the continuous features. +2. **TabResnet**: similar to the previous model but the embeddings are passed through a series of ResNet blocks built with dense layers. - -3. ``Tabnet``: Details on TabNet can be found in: +3. **TabNet**: details on TabNet can be found in [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) -4. ``TabTransformer``: Details on the TabTransformer can be found in: -[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf). -Note that the TabTransformer implementation available at ``pytorch-widedeep`` -is an adaptation of the original implementation. +And the ``Tabformer`` family, i.e. Transformers for Tabular data: -5. ``FT-Transformer``: or Feature Tokenizer transformer. This is a relatively small -variation of the ``TabTransformer``. The variation itself was first -introduced in the ``SAINT`` paper, but the name "``FT-Transformer``" was first -used in +4. **TabTransformer**: details on the TabTransformer can be found in +[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf). +5. **SAINT**: Details on SAINT can be found in +[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342). +6. **FT-Transformer**: details on the FT-Transformer can be found in [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959). -When using the ``FT-Transformer`` each continuous feature is "embedded" -(i.e. going through a 1-layer MLP with or without activation function) and -then passed through the attention blocks along with the categorical features. -This is available in ``pytorch-widedeep``'s ``TabTransformer`` by setting the -parameter ``embed_continuous = True``. - +7. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details +on the Fasformer can be found in +[FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382) +8. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on +the Perceiver can be found in +[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206) -6. ``SAINT``: Details on SAINT can be found in: -[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342). +Note that while there are scientific publications for the TabTransformer, +SAINT and FT-Transformer, the TabFasfFormer and TabPerceiver are our own +adaptation of those algorithms for tabular data. For details on these models and their options please see the examples in the Examples folder and the documentation. -Finally, while I recommend using the ``wide`` and ``deeptabular`` models in -``pytorch-widedeep`` it is very likely that users will want to use their own -models for the ``deeptext`` and ``deepimage`` components. That is perfectly -possible as long as the the custom models have an attribute called -``output_dim`` with the size of the last layer of activations, so that -``WideDeep`` can be constructed. Again, examples on how to use custom -components can be found in the Examples folder. Just in case -``pytorch-widedeep`` includes standard text (stack of LSTMs) and image -(pre-trained ResNets or stack of CNNs) models. - - -### Installation +### Installation Install using pip: @@ -167,8 +171,8 @@ when running on Mac, present in previous versions, persist on this release and the data-loaders will not run in parallel. In addition, since `python 3.8`, [the `multiprocessing` library start method changed from `'fork'` to`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods). This also affects the data-loaders (for any `torch` version) and they will -not run in parallel. Therefore, for Mac users I recommend using `python -3.6` or `3.7` and `torch <= 1.6` (with the corresponding, consistent +not run in parallel. Therefore, for Mac users I recommend using `python 3.7` +and `torch <= 1.6` (with the corresponding, consistent version of `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this versioning in the `setup.py` file since I expect that all these issues are fixed in the future. Therefore, after installing diff --git a/VERSION b/VERSION index 1464c521f9e176df34fe6fbd7cb6712211c03724..e5a4a5e7d84da0d6e54ff66fa898ca48d847bc04 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0.5 \ No newline at end of file +1.0.9 \ No newline at end of file diff --git a/docs/examples.rst b/docs/examples.rst index c1e83a0efc7e8ea108d7fd8eeede0441aa5f5f2b..3af3f85a3c40e689d0d885a7920c036a022920f6 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -16,3 +16,4 @@ them to address different problems * `Save and Load Model and Artifacts `__ * `Using Custom DataLoaders and Torchmetrics `__ * `The Transformer Family `__ +* `Extracting Embeddings `__ diff --git a/docs/index.rst b/docs/index.rst index 96659654795ede993088ef4ec9b20dd14920334e..67f70f6b6cc772b1c310bb9c0d2441ec4123f0af 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -23,6 +23,7 @@ Documentation Dataloaders Callbacks The Trainer + Tab2Vec Examples diff --git a/docs/losses.rst b/docs/losses.rst index 726a80594c0efcc3ad34252ac502949b50c585dc..d2b6ad738d22dce302c7df303730d17e67fa60f5 100644 --- a/docs/losses.rst +++ b/docs/losses.rst @@ -17,8 +17,8 @@ on their own and can be imported as: from pytorch_widedeep.losses import FocalLoss .. note:: Losses in this module expect the predictions and ground truth to have the - same dimensions for regression and binary classification problems (i.e. - :math:`N_{samples}, 1)`. In the case of multiclass classification problems + same dimensions for regression and binary classification problems + :math:`(N_{samples}, 1)`. In the case of multiclass classification problems the ground truth is expected to be a 1D tensor with the corresponding classes. See Examples below diff --git a/docs/metrics.rst b/docs/metrics.rst index 4869f465cda0f86358f8272bd1ac931ca15f0583..0d0e84c40c0871a3e7bfe6b5dd93ded9b5fd6d5f 100644 --- a/docs/metrics.rst +++ b/docs/metrics.rst @@ -2,10 +2,9 @@ Metrics ======= .. note:: Metrics in this module expect the predictions and ground truth to have the - same dimensions for regression and binary classification problems (i.e. - :math:`N_{samples}, 1)`. In the case of multiclass classification problems the - ground truth is expected to be a 1D tensor with the corresponding classes. - See Examples below + same dimensions for regression and binary classification problems: :math:`(N_{samples}, 1)`. + In the case of multiclass classification problems the ground truth is expected to be + a 1D tensor with the corresponding classes. See Examples below We have added the possibility of using the metrics available at the `torchmetrics `_ library. diff --git a/docs/model_components.rst b/docs/model_components.rst index 7f68b5ffb775df60adb49285b5ea96cfb84969d0..4afadfcbfb6e7726ffb6c41ca48fc5f78ff9f198 100644 --- a/docs/model_components.rst +++ b/docs/model_components.rst @@ -5,9 +5,10 @@ This module contains the four main components that will comprise a Wide and Deep model, and the ``WideDeep`` "constructor" class. These four components are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``. -.. note:: ``TabMlp``, ``TabResnet``, ``TabNet``, ``TabTransformer`` and ``SAINT`` can - all be used as the ``deeptabular`` component of the model and simply - represent different alternatives +.. note:: ``TabMlp``, ``TabResnet``, ``TabNet``, ``TabTransformer``, ``SAINT``, + ``FTTransformer``, ``TabPerceiver`` and ``TabFastFormer`` can all be used + as the ``deeptabular`` component of the model and simply represent different + alternatives .. autoclass:: pytorch_widedeep.models.wide.Wide :exclude-members: forward @@ -33,6 +34,18 @@ are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``. :exclude-members: forward :members: +.. autoclass:: pytorch_widedeep.models.transformers.ft_transformer.FTTransformer + :exclude-members: forward + :members: + +.. autoclass:: pytorch_widedeep.models.transformers.tab_perceiver.TabPerceiver + :exclude-members: forward + :members: + +.. autoclass:: pytorch_widedeep.models.transformers.tab_fastformer.TabFastFormer + :exclude-members: forward + :members: + .. autoclass:: pytorch_widedeep.models.deep_text.DeepText :exclude-members: forward :members: diff --git a/docs/tab2vec.rst b/docs/tab2vec.rst new file mode 100644 index 0000000000000000000000000000000000000000..8daf940c62f68bd47517431fa1bbefe2c188d786 --- /dev/null +++ b/docs/tab2vec.rst @@ -0,0 +1,7 @@ +Tab2Vec +======= + +.. autoclass:: pytorch_widedeep.tab2vec.Tab2Vec + :members: + :undoc-members: + diff --git a/examples/02_2_deeptabular_models.ipynb b/examples/02_2_deeptabular_models.ipynb index 381c35bb1a990c2f43eb4b0bbb76e1035e7f1814..37b439eed404c4db0d009ed2da2283c428b46e5e 100644 --- a/examples/02_2_deeptabular_models.ipynb +++ b/examples/02_2_deeptabular_models.ipynb @@ -96,14 +96,16 @@ "data": { "text/plain": [ "TabMlp(\n", - " (embed_layers): ModuleDict(\n", - " (emb_layer_a): Embedding(5, 8, padding_idx=0)\n", - " (emb_layer_b): Embedding(5, 8, padding_idx=0)\n", - " (emb_layer_c): Embedding(5, 8, padding_idx=0)\n", - " (emb_layer_d): Embedding(5, 8, padding_idx=0)\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (embed_layers): ModuleDict(\n", + " (emb_layer_a): Embedding(5, 8, padding_idx=0)\n", + " (emb_layer_b): Embedding(5, 8, padding_idx=0)\n", + " (emb_layer_c): Embedding(5, 8, padding_idx=0)\n", + " (emb_layer_d): Embedding(5, 8, padding_idx=0)\n", + " )\n", + " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", + " (cont_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (tab_mlp): MLP(\n", " (mlp): Sequential(\n", " (dense_layer_0): Sequential(\n", @@ -136,7 +138,7 @@ "source": [ "Note that the input dimension of the MLP is `33`, `32` from the embeddings and `1` for the continuous features. Before we move on, is worth commenting an aspect that applies to all models discussed here. The `TabPreprocessor` included in this package gives the user the possibility of standarising the input via `sklearn`'s `StandardScaler`. Alternatively, or in addition to it, it is possible to add a continuous normalization layer (`BatchNorm1d` or `LayerNorm`). To do so simply set the `cont_norm_layer` as indicated in the example above. See also the docs.\n", "\n", - "I will insist on this in this and the following sections. Note that `TabMlp` (or any of the wide and deep components) does not build the final connection with the final neuron(s). This is done by the ``WideDeep`` class, which collects all wide and deep components and connects them to the output neuron(s).\n", + "I will insist on this in this in here and the following sections. Note that `TabMlp` (or any of the wide and deep components) does not build the final connection with the final neuron(s). This is done by the ``WideDeep`` class, which collects all wide and deep components and connects them to the output neuron(s).\n", "\n", "For example:" ] @@ -170,14 +172,16 @@ "WideDeep(\n", " (deeptabular): Sequential(\n", " (0): TabMlp(\n", - " (embed_layers): ModuleDict(\n", - " (emb_layer_a): Embedding(5, 8, padding_idx=0)\n", - " (emb_layer_b): Embedding(5, 8, padding_idx=0)\n", - " (emb_layer_c): Embedding(5, 8, padding_idx=0)\n", - " (emb_layer_d): Embedding(5, 8, padding_idx=0)\n", - " )\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (embed_layers): ModuleDict(\n", + " (emb_layer_a): Embedding(5, 8, padding_idx=0)\n", + " (emb_layer_b): Embedding(5, 8, padding_idx=0)\n", + " (emb_layer_c): Embedding(5, 8, padding_idx=0)\n", + " (emb_layer_d): Embedding(5, 8, padding_idx=0)\n", + " )\n", + " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", + " (cont_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", " (tab_mlp): MLP(\n", " (mlp): Sequential(\n", " (dense_layer_0): Sequential(\n", @@ -283,14 +287,16 @@ "data": { "text/plain": [ "TabResnet(\n", - " (embed_layers): ModuleDict(\n", - " (emb_layer_a): Embedding(5, 8, padding_idx=0)\n", - " (emb_layer_b): Embedding(5, 8, padding_idx=0)\n", - " (emb_layer_c): Embedding(5, 8, padding_idx=0)\n", - " (emb_layer_d): Embedding(5, 8, padding_idx=0)\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (embed_layers): ModuleDict(\n", + " (emb_layer_a): Embedding(5, 8, padding_idx=0)\n", + " (emb_layer_b): Embedding(5, 8, padding_idx=0)\n", + " (emb_layer_c): Embedding(5, 8, padding_idx=0)\n", + " (emb_layer_d): Embedding(5, 8, padding_idx=0)\n", + " )\n", + " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", + " (cont_norm): LayerNorm((1,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): LayerNorm((1,), eps=1e-05, elementwise_affine=True)\n", " (tab_resnet_blks): DenseResnet(\n", " (dense_resnet): Sequential(\n", " (lin1): Linear(in_features=32, out_features=16, bias=True)\n", @@ -420,7 +426,7 @@ "data": { "text/plain": [ "TabNet(\n", - " (embed_and_cont): EmbeddingsAndContinuous(\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", " (embed_layers): ModuleDict(\n", " (emb_layer_a): Embedding(5, 8, padding_idx=0)\n", " (emb_layer_b): Embedding(5, 8, padding_idx=0)\n", @@ -428,7 +434,7 @@ " (emb_layer_d): Embedding(5, 8, padding_idx=0)\n", " )\n", " (embedding_dropout): Dropout(p=0.0, inplace=False)\n", - " (cont_norm): Identity()\n", + " (cont_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (tabnet_encoder): TabNetEncoder(\n", " (initial_bn): BatchNorm1d(33, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n", @@ -588,11 +594,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4 and 5. `TabTransformer` and the `Feature-Tokenizer Transformer`\n", + "# 4 The transformers family\n", + "\n", + "For a tour on all transformer-based models, please, see the Transformer Family Notebook. All the content below is in that notebook.\n", + "\n", + "## 4.1 `TabTransformer` \n", "\n", "Details on the `TabTransformer` can be found in [TabTransformer: Tabular Data Modeling\n", - "Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf). The `FT-Transformer` is a variant introduced in the following two papers: [SAINT: Improved Neural Networks for Tabular Data\n", - "via Row Attention and Contrastive Pre-Training](https://arxiv.org/pdf/2106.01342.pdf) and [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/pdf/2106.11959.pdf). The name itself (`FT-Transformer`) was first used in the latter, but the variant (which I will explain in a second) was already introduced in the `SAINT` paper. \n", + "Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).\n", "\n", "In general terms, the `TabTransformer` takes the embeddings from the categorical columns that are then passed through a Tranformer encoder, concatenated with the normalised continuous features, and then passed through an MLP. Let's have a look:\n", "\n", @@ -614,7 +623,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Using the `FT-Transformer` with `pytorch-widedeep` is simply available by setting the param `embed_continuous` to `True`. In addition, I have also added the possibility of pooling all outputs from the transformer blocks using the `[CLS]` token. Otherwise all the outputs form the transformer blocks will be concatenated. Look at some of the other example notebooks for more details. " + "A variant of the `TabTransformer` is the `FT-Transformer`, which was introduced in is a variant introduced in [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/pdf/2106.11959.pdf). The two main additions were continuous embeddings and Linear Attention. \n", + "\n", + "Continuous embeddings were already introduce in the `SAINT` paper: [SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/pdf/2106.01342.pdf).\n", + "\n", + "There is a dedicated `FTTransformer` model in the library that one can check in the `Transformers Family` notebook. Nonetheless, using the `TabTransformer` with continuous embeddings is as easy as setting the param `embed_continuous` to `True`. In addition, I have also added the possibility of pooling all outputs from the transformer blocks using the `[CLS]` token. Otherwise all the outputs form the transformer blocks will be concatenated. Look at some of the other example notebooks for more details. " ] }, { @@ -661,57 +674,20 @@ "data": { "text/plain": [ "TabTransformer(\n", - " (cat_embed): Embedding(17, 32, padding_idx=0)\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): Identity()\n", - " (transformer_blks): Sequential(\n", - " (block0): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", + " (cat_and_cont_embed): CatAndContEmbeddings(\n", + " (cat_embed): CategoricalEmbeddings(\n", + " (embed): Embedding(17, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", - " (block1): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (block2): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (cont_norm): Identity()\n", + " )\n", + " (transformer_blks): Sequential(\n", + " (transformer_block0): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -720,7 +696,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -728,11 +704,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block3): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block1): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -741,7 +718,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -749,11 +726,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block4): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block2): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -762,7 +740,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -770,11 +748,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block5): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block3): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -783,7 +762,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -824,14 +803,14 @@ "metadata": {}, "outputs": [], "source": [ - "ft_transformer = TabTransformer(\n", + "tab_transformer = TabTransformer(\n", " column_idx=column_idx, \n", " embed_input=embed_input, \n", " continuous_cols=continuous_cols,\n", " embed_continuous=True,\n", " embed_continuous_activation=\"relu\",\n", ")\n", - "out = ft_transformer(X_tab) " + "out = tab_transformer(X_tab) " ] }, { @@ -845,60 +824,23 @@ "data": { "text/plain": [ "TabTransformer(\n", - " (cat_embed): Embedding(17, 32, padding_idx=0)\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): Identity()\n", - " (cont_embed): ContinuousEmbeddings(\n", - " (act_fn): ReLU(inplace=True)\n", - " )\n", - " (transformer_blks): Sequential(\n", - " (block0): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", + " (cat_and_cont_embed): CatAndContEmbeddings(\n", + " (cat_embed): CategoricalEmbeddings(\n", + " (embed): Embedding(17, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", - " (block1): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", + " (cont_norm): Identity()\n", + " (cont_embed): ContinuousEmbeddings(\n", + " (act_fn): ReLU(inplace=True)\n", " )\n", - " (block2): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (transformer_blks): Sequential(\n", + " (transformer_block0): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -907,7 +849,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -915,11 +857,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block3): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block1): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -928,7 +871,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -936,11 +879,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block4): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block2): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -949,7 +893,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -957,11 +901,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block5): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block3): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -970,7 +915,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -1002,7 +947,7 @@ } ], "source": [ - "ft_transformer" + "tab_transformer" ] }, { @@ -1065,7 +1010,6 @@ " column_idx=column_idx, \n", " embed_input=embed_input, \n", " continuous_cols=continuous_cols,\n", - " embed_continuous=True,\n", " embed_continuous_activation=\"leaky_relu\",\n", ")\n", "out = saint(X_tab) " @@ -1082,202 +1026,48 @@ "data": { "text/plain": [ "SAINT(\n", - " (cat_embed): Embedding(17, 32, padding_idx=0)\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): LayerNorm((1,), eps=1e-05, elementwise_affine=True)\n", - " (cont_embed): ContinuousEmbeddings(\n", - " (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (transformer_blks): Sequential(\n", - " (block0): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (self_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (self_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=160, out_features=480, bias=True)\n", - " (out_proj): Linear(in_features=160, out_features=160, bias=True)\n", - " )\n", - " (row_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=160, out_features=640, bias=True)\n", - " (w_2): Linear(in_features=640, out_features=160, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (row_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (block1): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (self_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (self_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=160, out_features=480, bias=True)\n", - " (out_proj): Linear(in_features=160, out_features=160, bias=True)\n", - " )\n", - " (row_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=160, out_features=640, bias=True)\n", - " (w_2): Linear(in_features=640, out_features=160, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (row_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", - " )\n", + " (cat_and_cont_embed): CatAndContEmbeddings(\n", + " (cat_embed): CategoricalEmbeddings(\n", + " (embed): Embedding(17, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", - " (block2): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (self_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (self_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=160, out_features=480, bias=True)\n", - " (out_proj): Linear(in_features=160, out_features=160, bias=True)\n", - " )\n", - " (row_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=160, out_features=640, bias=True)\n", - " (w_2): Linear(in_features=640, out_features=160, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (row_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (block3): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (self_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (self_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=160, out_features=480, bias=True)\n", - " (out_proj): Linear(in_features=160, out_features=160, bias=True)\n", - " )\n", - " (row_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=160, out_features=640, bias=True)\n", - " (w_2): Linear(in_features=640, out_features=160, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (row_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", - " )\n", + " (cont_norm): Identity()\n", + " (cont_embed): ContinuousEmbeddings(\n", + " (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (block4): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", + " )\n", + " (transformer_blks): Sequential(\n", + " (saint_block0): SaintEncoder(\n", + " (col_attn): MultiHeadedAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", - " (self_attn_ff): PositionwiseFF(\n", + " (col_attn_ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): GELU()\n", " )\n", - " (self_attn_addnorm): AddNorm(\n", + " (col_attn_addnorm): AddNorm(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (col_attn_ff_addnorm): AddNorm(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (row_attn): MultiHeadedAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=160, out_features=480, bias=True)\n", - " (out_proj): Linear(in_features=160, out_features=160, bias=True)\n", + " (q_proj): Linear(in_features=160, out_features=160, bias=False)\n", + " (kv_proj): Linear(in_features=160, out_features=320, bias=False)\n", + " (out_proj): Linear(in_features=160, out_features=160, bias=False)\n", " )\n", " (row_attn_ff): PositionwiseFF(\n", " (w_1): Linear(in_features=160, out_features=640, bias=True)\n", " (w_2): Linear(in_features=640, out_features=160, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): GELU()\n", " )\n", " (row_attn_addnorm): AddNorm(\n", @@ -1285,39 +1075,41 @@ " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block5): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", + " (saint_block1): SaintEncoder(\n", + " (col_attn): MultiHeadedAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", - " (self_attn_ff): PositionwiseFF(\n", + " (col_attn_ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): GELU()\n", " )\n", - " (self_attn_addnorm): AddNorm(\n", + " (col_attn_addnorm): AddNorm(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (col_attn_ff_addnorm): AddNorm(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (row_attn): MultiHeadedAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=160, out_features=480, bias=True)\n", - " (out_proj): Linear(in_features=160, out_features=160, bias=True)\n", + " (q_proj): Linear(in_features=160, out_features=160, bias=False)\n", + " (kv_proj): Linear(in_features=160, out_features=320, bias=False)\n", + " (out_proj): Linear(in_features=160, out_features=160, bias=False)\n", " )\n", " (row_attn_ff): PositionwiseFF(\n", " (w_1): Linear(in_features=160, out_features=640, bias=True)\n", " (w_2): Linear(in_features=640, out_features=160, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): GELU()\n", " )\n", " (row_attn_addnorm): AddNorm(\n", @@ -1325,7 +1117,7 @@ " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((160,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", diff --git a/examples/03_Binary_Classification_with_Defaults.ipynb b/examples/03_Binary_Classification_with_Defaults.ipynb index edeadc9c38fda3a717e7e5f9cbb4b0ffd7b524f1..1f0e2071df761214f1c017e4efb77ac4a0e3a865 100644 --- a/examples/03_Binary_Classification_with_Defaults.ipynb +++ b/examples/03_Binary_Classification_with_Defaults.ipynb @@ -511,15 +511,17 @@ " )\n", " (deeptabular): Sequential(\n", " (0): TabMlp(\n", - " (embed_layers): ModuleDict(\n", - " (emb_layer_education): Embedding(17, 16, padding_idx=0)\n", - " (emb_layer_native_country): Embedding(43, 16, padding_idx=0)\n", - " (emb_layer_occupation): Embedding(16, 16, padding_idx=0)\n", - " (emb_layer_relationship): Embedding(7, 8, padding_idx=0)\n", - " (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (embed_layers): ModuleDict(\n", + " (emb_layer_education): Embedding(17, 16, padding_idx=0)\n", + " (emb_layer_native_country): Embedding(43, 16, padding_idx=0)\n", + " (emb_layer_occupation): Embedding(16, 16, padding_idx=0)\n", + " (emb_layer_relationship): Embedding(7, 8, padding_idx=0)\n", + " (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n", + " )\n", + " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", + " (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (tab_mlp): MLP(\n", " (mlp): Sequential(\n", " (dense_layer_0): Sequential(\n", @@ -589,16 +591,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 611/611 [00:07<00:00, 83.83it/s, loss=0.425, metrics={'acc': 0.801, 'prec': 0.6074}] \n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 129.25it/s, loss=0.362, metrics={'acc': 0.8341, 'prec': 0.6947}]\n", - "epoch 2: 100%|██████████| 611/611 [00:07<00:00, 79.96it/s, loss=0.373, metrics={'acc': 0.8245, 'prec': 0.6621}]\n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 140.03it/s, loss=0.356, metrics={'acc': 0.8353, 'prec': 0.6742}]\n", - "epoch 3: 100%|██████████| 611/611 [00:07<00:00, 79.08it/s, loss=0.364, metrics={'acc': 0.8288, 'prec': 0.6729}]\n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 150.18it/s, loss=0.35, metrics={'acc': 0.838, 'prec': 0.6875}] \n", - "epoch 4: 100%|██████████| 611/611 [00:07<00:00, 82.86it/s, loss=0.358, metrics={'acc': 0.8319, 'prec': 0.6814}]\n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 147.48it/s, loss=0.345, metrics={'acc': 0.8394, 'prec': 0.6949}]\n", - "epoch 5: 100%|██████████| 611/611 [00:07<00:00, 78.20it/s, loss=0.354, metrics={'acc': 0.8337, 'prec': 0.6872}]\n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 150.62it/s, loss=0.344, metrics={'acc': 0.8426, 'prec': 0.7066}]\n" + "epoch 1: 100%|██████████| 611/611 [00:07<00:00, 78.90it/s, loss=0.477, metrics={'acc': 0.7763, 'prec': 0.5377}]\n", + "valid: 100%|██████████| 153/153 [00:01<00:00, 124.13it/s, loss=0.387, metrics={'acc': 0.8148, 'prec': 0.6034}]\n", + "epoch 2: 100%|██████████| 611/611 [00:06<00:00, 88.99it/s, loss=0.383, metrics={'acc': 0.8205, 'prec': 0.6525}]\n", + "valid: 100%|██████████| 153/153 [00:01<00:00, 116.64it/s, loss=0.364, metrics={'acc': 0.832, 'prec': 0.6629}] \n", + "epoch 3: 100%|██████████| 611/611 [00:09<00:00, 67.26it/s, loss=0.372, metrics={'acc': 0.8264, 'prec': 0.6683}]\n", + "valid: 100%|██████████| 153/153 [00:01<00:00, 145.81it/s, loss=0.355, metrics={'acc': 0.8343, 'prec': 0.669}] \n", + "epoch 4: 100%|██████████| 611/611 [00:07<00:00, 86.61it/s, loss=0.36, metrics={'acc': 0.8306, 'prec': 0.6784}] \n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 168.01it/s, loss=0.354, metrics={'acc': 0.8323, 'prec': 0.6549}]\n", + "epoch 5: 100%|██████████| 611/611 [00:06<00:00, 87.79it/s, loss=0.357, metrics={'acc': 0.8321, 'prec': 0.6841}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 171.45it/s, loss=0.352, metrics={'acc': 0.8341, 'prec': 0.6671}]\n" ] } ], @@ -647,16 +649,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 611/611 [00:08<00:00, 67.98it/s, loss=0.385, metrics={'acc': 0.8182, 'prec': 0.6465}]\n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 146.93it/s, loss=0.359, metrics={'acc': 0.8361, 'prec': 0.6862}]\n", - "epoch 2: 100%|██████████| 611/611 [00:09<00:00, 67.33it/s, loss=0.363, metrics={'acc': 0.8296, 'prec': 0.6756}]\n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 148.88it/s, loss=0.354, metrics={'acc': 0.8353, 'prec': 0.7058}]\n", - "epoch 3: 100%|██████████| 611/611 [00:09<00:00, 67.31it/s, loss=0.357, metrics={'acc': 0.8312, 'prec': 0.6822}]\n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 142.79it/s, loss=0.351, metrics={'acc': 0.8387, 'prec': 0.6813}]\n", - "epoch 4: 100%|██████████| 611/611 [00:09<00:00, 62.17it/s, loss=0.353, metrics={'acc': 0.8347, 'prec': 0.6897}]\n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 124.27it/s, loss=0.348, metrics={'acc': 0.8404, 'prec': 0.692}] \n", - "epoch 5: 100%|██████████| 611/611 [00:17<00:00, 35.55it/s, loss=0.35, metrics={'acc': 0.8347, 'prec': 0.6893}] \n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 116.47it/s, loss=0.345, metrics={'acc': 0.8427, 'prec': 0.6936}]\n" + "epoch 1: 100%|██████████| 611/611 [00:08<00:00, 70.16it/s, loss=0.439, metrics={'acc': 0.7865, 'prec': 0.5561}]\n", + "valid: 100%|██████████| 153/153 [00:01<00:00, 152.45it/s, loss=0.361, metrics={'acc': 0.8349, 'prec': 0.6803}]\n", + "epoch 2: 100%|██████████| 611/611 [00:08<00:00, 70.18it/s, loss=0.373, metrics={'acc': 0.8236, 'prec': 0.6609}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 157.46it/s, loss=0.354, metrics={'acc': 0.839, 'prec': 0.704}] \n", + "epoch 3: 100%|██████████| 611/611 [00:08<00:00, 70.71it/s, loss=0.363, metrics={'acc': 0.8294, 'prec': 0.6717}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 155.59it/s, loss=0.353, metrics={'acc': 0.8381, 'prec': 0.6954}]\n", + "epoch 4: 100%|██████████| 611/611 [00:08<00:00, 69.53it/s, loss=0.358, metrics={'acc': 0.8323, 'prec': 0.683}] \n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 157.64it/s, loss=0.353, metrics={'acc': 0.8339, 'prec': 0.658}] \n", + "epoch 5: 100%|██████████| 611/611 [00:08<00:00, 69.47it/s, loss=0.354, metrics={'acc': 0.8338, 'prec': 0.6851}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 157.29it/s, loss=0.353, metrics={'acc': 0.8375, 'prec': 0.6764}]\n" ] } ], @@ -668,7 +670,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Using the `FT-Transformer` as the `deeptabular` component" + "Using the `TabTransformer` as the `deeptabular` component" ] }, { @@ -703,7 +705,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -717,7 +719,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -732,7 +734,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -741,17 +743,17 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 77/77 [00:21<00:00, 3.57it/s, loss=0.446, metrics={'acc': 0.7905, 'prec': 0.5849}]\n", - "valid: 100%|██████████| 20/20 [00:01<00:00, 13.74it/s, loss=0.374, metrics={'acc': 0.8227, 'prec': 0.6443}]\n", - "epoch 2: 100%|██████████| 77/77 [00:21<00:00, 3.63it/s, loss=0.377, metrics={'acc': 0.8231, 'prec': 0.6586}]\n", - "valid: 100%|██████████| 20/20 [00:01<00:00, 14.04it/s, loss=0.372, metrics={'acc': 0.8216, 'prec': 0.6112}]\n" + "epoch 1: 100%|██████████| 77/77 [00:20<00:00, 3.79it/s, loss=0.667, metrics={'acc': 0.6787, 'prec': 0.3306}]\n", + "valid: 100%|██████████| 20/20 [00:01<00:00, 14.78it/s, loss=0.409, metrics={'acc': 0.8033, 'prec': 0.583}] \n", + "epoch 2: 100%|██████████| 77/77 [00:21<00:00, 3.52it/s, loss=0.403, metrics={'acc': 0.8136, 'prec': 0.6326}]\n", + "valid: 100%|██████████| 20/20 [00:01<00:00, 14.32it/s, loss=0.374, metrics={'acc': 0.8224, 'prec': 0.6504}]\n" ] } ], @@ -768,7 +770,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -777,7 +779,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -786,23 +788,23 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 611/611 [00:03<00:00, 159.50it/s, loss=0.46, metrics={'acc': 0.7836, 'prec': 0.573}] \n", - "valid: 100%|██████████| 153/153 [00:00<00:00, 168.08it/s, loss=0.422, metrics={'acc': 0.805, 'prec': 0.6403}] \n", - "epoch 2: 100%|██████████| 611/611 [00:03<00:00, 158.35it/s, loss=0.405, metrics={'acc': 0.8131, 'prec': 0.6643}]\n", - "valid: 100%|██████████| 153/153 [00:00<00:00, 214.83it/s, loss=0.394, metrics={'acc': 0.8168, 'prec': 0.6741}]\n", - "epoch 3: 100%|██████████| 611/611 [00:03<00:00, 162.96it/s, loss=0.385, metrics={'acc': 0.8201, 'prec': 0.6837}]\n", - "valid: 100%|██████████| 153/153 [00:00<00:00, 210.59it/s, loss=0.381, metrics={'acc': 0.8228, 'prec': 0.6799}]\n", - "epoch 4: 100%|██████████| 611/611 [00:03<00:00, 171.06it/s, loss=0.375, metrics={'acc': 0.8256, 'prec': 0.6895}]\n", - "valid: 100%|██████████| 153/153 [00:00<00:00, 210.97it/s, loss=0.374, metrics={'acc': 0.8259, 'prec': 0.6798}]\n", - "epoch 5: 100%|██████████| 611/611 [00:03<00:00, 157.49it/s, loss=0.369, metrics={'acc': 0.828, 'prec': 0.692}] \n", - "valid: 100%|██████████| 153/153 [00:00<00:00, 197.19it/s, loss=0.37, metrics={'acc': 0.8275, 'prec': 0.6856}] \n" + "epoch 1: 100%|██████████| 611/611 [00:03<00:00, 166.24it/s, loss=0.799, metrics={'acc': 0.5446, 'prec': 0.229}] \n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 261.71it/s, loss=0.56, metrics={'acc': 0.7394, 'prec': 0.4044}] \n", + "epoch 2: 100%|██████████| 611/611 [00:03<00:00, 167.73it/s, loss=0.491, metrics={'acc': 0.7699, 'prec': 0.5417}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 259.04it/s, loss=0.446, metrics={'acc': 0.7939, 'prec': 0.6517}]\n", + "epoch 3: 100%|██████████| 611/611 [00:03<00:00, 168.46it/s, loss=0.425, metrics={'acc': 0.809, 'prec': 0.6809}] \n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 257.24it/s, loss=0.406, metrics={'acc': 0.8201, 'prec': 0.7044}]\n", + "epoch 4: 100%|██████████| 611/611 [00:03<00:00, 169.35it/s, loss=0.397, metrics={'acc': 0.8176, 'prec': 0.6909}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 174.28it/s, loss=0.388, metrics={'acc': 0.8248, 'prec': 0.7048}]\n", + "epoch 5: 100%|██████████| 611/611 [00:04<00:00, 142.70it/s, loss=0.382, metrics={'acc': 0.8239, 'prec': 0.6947}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 207.43it/s, loss=0.378, metrics={'acc': 0.8288, 'prec': 0.6999}]\n" ] } ], diff --git a/examples/04_Binary_Classification_Varying_Parameters.ipynb b/examples/04_Binary_Classification_Varying_Parameters.ipynb index fe219de09faf8bc1142a76a18c51198fc20c24ca..23ba480cdc469fe27a0636183edc4f229ea22fc5 100644 --- a/examples/04_Binary_Classification_Varying_Parameters.ipynb +++ b/examples/04_Binary_Classification_Varying_Parameters.ipynb @@ -533,15 +533,17 @@ " )\n", " (deeptabular): Sequential(\n", " (0): TabMlp(\n", - " (embed_layers): ModuleDict(\n", - " (emb_layer_education): Embedding(17, 16, padding_idx=0)\n", - " (emb_layer_native_country): Embedding(43, 16, padding_idx=0)\n", - " (emb_layer_occupation): Embedding(16, 16, padding_idx=0)\n", - " (emb_layer_relationship): Embedding(7, 8, padding_idx=0)\n", - " (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (embed_layers): ModuleDict(\n", + " (emb_layer_education): Embedding(17, 16, padding_idx=0)\n", + " (emb_layer_native_country): Embedding(43, 16, padding_idx=0)\n", + " (emb_layer_occupation): Embedding(16, 16, padding_idx=0)\n", + " (emb_layer_relationship): Embedding(7, 8, padding_idx=0)\n", + " (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n", + " )\n", + " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", + " (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (tab_mlp): MLP(\n", " (mlp): Sequential(\n", " (dense_layer_0): Sequential(\n", @@ -658,26 +660,26 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 153/153 [00:03<00:00, 40.76it/s, loss=0.605, metrics={'acc': 0.7653, 'rec': 0.5005}]\n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 70.56it/s, loss=0.37, metrics={'acc': 0.8295, 'rec': 0.5646}] \n", - "epoch 2: 100%|██████████| 153/153 [00:03<00:00, 42.82it/s, loss=0.37, metrics={'acc': 0.8298, 'rec': 0.5627}] \n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 116.22it/s, loss=0.355, metrics={'acc': 0.8372, 'rec': 0.6206}]\n", - "epoch 3: 100%|██████████| 153/153 [00:03<00:00, 41.82it/s, loss=0.354, metrics={'acc': 0.8338, 'rec': 0.5612}]\n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 116.42it/s, loss=0.35, metrics={'acc': 0.8395, 'rec': 0.5804}] \n", - "epoch 4: 100%|██████████| 153/153 [00:03<00:00, 42.66it/s, loss=0.345, metrics={'acc': 0.8382, 'rec': 0.5658}]\n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 115.17it/s, loss=0.35, metrics={'acc': 0.8379, 'rec': 0.6048}] \n", - "epoch 5: 100%|██████████| 153/153 [00:03<00:00, 42.11it/s, loss=0.343, metrics={'acc': 0.8391, 'rec': 0.5681}]\n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 115.60it/s, loss=0.347, metrics={'acc': 0.84, 'rec': 0.595}] \n", - "epoch 6: 100%|██████████| 153/153 [00:03<00:00, 41.32it/s, loss=0.341, metrics={'acc': 0.8398, 'rec': 0.5748}]\n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 109.95it/s, loss=0.347, metrics={'acc': 0.8404, 'rec': 0.5855}]\n", - "epoch 7: 100%|██████████| 153/153 [00:03<00:00, 41.79it/s, loss=0.34, metrics={'acc': 0.8413, 'rec': 0.5746}] \n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 108.11it/s, loss=0.347, metrics={'acc': 0.8395, 'rec': 0.5898}]\n", - "epoch 8: 100%|██████████| 153/153 [00:03<00:00, 41.09it/s, loss=0.341, metrics={'acc': 0.8395, 'rec': 0.5744}]\n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 99.26it/s, loss=0.347, metrics={'acc': 0.8404, 'rec': 0.5877}]\n", - "epoch 9: 100%|██████████| 153/153 [00:03<00:00, 41.33it/s, loss=0.34, metrics={'acc': 0.8409, 'rec': 0.573}] \n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 108.59it/s, loss=0.347, metrics={'acc': 0.8399, 'rec': 0.5778}]\n", - "epoch 10: 100%|██████████| 153/153 [00:03<00:00, 40.06it/s, loss=0.34, metrics={'acc': 0.8413, 'rec': 0.5718}] \n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 104.13it/s, loss=0.347, metrics={'acc': 0.8395, 'rec': 0.577}] \n" + "epoch 1: 100%|██████████| 153/153 [00:03<00:00, 42.78it/s, loss=0.562, metrics={'acc': 0.7779, 'rec': 0.488}] \n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 54.81it/s, loss=0.374, metrics={'acc': 0.8363, 'rec': 0.5684}]\n", + "epoch 2: 100%|██████████| 153/153 [00:03<00:00, 44.03it/s, loss=0.373, metrics={'acc': 0.8277, 'rec': 0.5535}]\n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 108.54it/s, loss=0.359, metrics={'acc': 0.8361, 'rec': 0.5915}]\n", + "epoch 3: 100%|██████████| 153/153 [00:03<00:00, 41.40it/s, loss=0.354, metrics={'acc': 0.8354, 'rec': 0.5686}]\n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 100.84it/s, loss=0.355, metrics={'acc': 0.8378, 'rec': 0.5346}]\n", + "epoch 4: 100%|██████████| 153/153 [00:03<00:00, 43.49it/s, loss=0.346, metrics={'acc': 0.8381, 'rec': 0.5653}]\n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 117.29it/s, loss=0.352, metrics={'acc': 0.8388, 'rec': 0.5633}]\n", + "epoch 5: 100%|██████████| 153/153 [00:03<00:00, 39.83it/s, loss=0.343, metrics={'acc': 0.8396, 'rec': 0.5669}]\n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 115.86it/s, loss=0.351, metrics={'acc': 0.8388, 'rec': 0.6074}]\n", + "epoch 6: 100%|██████████| 153/153 [00:03<00:00, 41.32it/s, loss=0.342, metrics={'acc': 0.8406, 'rec': 0.5758}]\n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 110.53it/s, loss=0.35, metrics={'acc': 0.84, 'rec': 0.5834}] \n", + "epoch 7: 100%|██████████| 153/153 [00:03<00:00, 40.08it/s, loss=0.341, metrics={'acc': 0.8407, 'rec': 0.5664}]\n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 108.04it/s, loss=0.35, metrics={'acc': 0.8399, 'rec': 0.5924}] \n", + "epoch 8: 100%|██████████| 153/153 [00:03<00:00, 40.74it/s, loss=0.341, metrics={'acc': 0.8397, 'rec': 0.573}] \n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 103.97it/s, loss=0.35, metrics={'acc': 0.8404, 'rec': 0.5881}] \n", + "epoch 9: 100%|██████████| 153/153 [00:03<00:00, 41.83it/s, loss=0.341, metrics={'acc': 0.8407, 'rec': 0.571}] \n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 112.66it/s, loss=0.35, metrics={'acc': 0.8398, 'rec': 0.595}] \n", + "epoch 10: 100%|██████████| 153/153 [00:03<00:00, 41.73it/s, loss=0.341, metrics={'acc': 0.8404, 'rec': 0.5751}]\n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 111.89it/s, loss=0.35, metrics={'acc': 0.8389, 'rec': 0.5787}] \n" ] }, { @@ -708,7 +710,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'train_loss': [0.6051353691450132, 0.3695722280764112, 0.35393014296986697, 0.3445955140917909, 0.34339318926038304, 0.3414274424898858, 0.33967164684744444, 0.3409811920589871, 0.3399035648193235, 0.3402750544688281], 'train_acc': [0.7653366775010877, 0.8297545619737414, 0.8337726819031045, 0.8382002917615745, 0.8391216441020654, 0.8397614721162951, 0.8412970593504466, 0.8395055409106033, 0.8408619763007703, 0.8412970593504466], 'train_rec': [0.5004813075065613, 0.5627340078353882, 0.5612365007400513, 0.5658358931541443, 0.5680821537971497, 0.5748208165168762, 0.5746068954467773, 0.5743929743766785, 0.5730024576187134, 0.5718258619308472], 'val_loss': [0.37032382075603193, 0.35480272387846923, 0.349816164909265, 0.3500071618801508, 0.3474775743790162, 0.3471213915409186, 0.34703178054247147, 0.3469086617995531, 0.3469338050255409, 0.3465542976672833], 'val_acc': [0.8294605384379159, 0.8372402497696796, 0.8394922714709796, 0.8378544375063978, 0.8400040945849114, 0.8404135530760569, 0.8394922714709796, 0.8404135530760569, 0.8399017299621251, 0.8394922714709796], 'val_rec': [0.5645850896835327, 0.6206158995628357, 0.5804105997085571, 0.6047903895378113, 0.5949529409408569, 0.5855432152748108, 0.589820384979248, 0.587681770324707, 0.5778443217277527, 0.5769888758659363]}\n" + "{'train_loss': [0.5623695554296955, 0.3727661143330967, 0.3543393321676192, 0.3463186333382052, 0.34326155766162997, 0.34202106482063244, 0.34081082270036334, 0.34090089836930915, 0.3412071953411975, 0.3405635129002964], 'train_acc': [0.7778517134594221, 0.8277071123282062, 0.8353594553783943, 0.8381235123998669, 0.83960791339288, 0.8405804519745093, 0.8406572313362168, 0.8396590996340184, 0.840682824456786, 0.8404268932510941], 'train_rec': [0.4879666268825531, 0.5535351634025574, 0.5686169862747192, 0.5653011202812195, 0.5669055581092834, 0.5757834911346436, 0.5663707256317139, 0.5730024576187134, 0.5709701776504517, 0.5751417279243469], 'val_loss': [0.374390076368283, 0.35924087579433733, 0.354536472986906, 0.35208039711683226, 0.35081761387678295, 0.3504261534947615, 0.350106044457509, 0.34991710613935423, 0.35027056473952073, 0.34997811913490295], 'val_acc': [0.8363189681646023, 0.8361142389190296, 0.8377520728836114, 0.8387757191114751, 0.8387757191114751, 0.8400040945849114, 0.8399017299621251, 0.8404135530760569, 0.8397993653393387, 0.8388780837342614], 'val_rec': [0.5684345364570618, 0.5915312170982361, 0.5346450209617615, 0.5633019804954529, 0.6073567271232605, 0.5834046006202698, 0.5923866629600525, 0.5881094932556152, 0.5949529409408569, 0.5786997675895691]}\n" ] } ], @@ -739,7 +741,9 @@ "source": [ "We can see that the learning rate effectively decreases by a factor of 0.1 (the default) after the corresponding `step_size`. Note that the keys of the dictionary have a suffix `_0`. This is because if you pass different parameter groups to the torch optimizers, these will also be recorded. We'll see this in the `Regression` notebook. \n", "\n", - "And I guess one has a good idea of how to use the package. Before we leave this notebook just mentioning that the `WideDeep` class comes with a useful method to \"rescue\" the learned embeddings. For example, let's say I want to use the embeddings learned for the different levels of the categorical feature `education`" + "And I guess one has a good idea of how to use the package. \n", + "\n", + "Before we leave this notebook just mentioning that the `WideDeep` class comes with a what is perhaps a useful method that I intend to deprecate in favor of `Tab2Vec`. This method, called `get_embeddings` is designed to \"rescue\" the learned embeddings. For example, let's say I want to use the embeddings learned for the different levels of the categorical feature `education`" ] }, { @@ -747,73 +751,82 @@ "execution_count": 18, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/javier/Projects/pytorch-widedeep/pytorch_widedeep/training/trainer.py:794: DeprecationWarning: 'get_embeddings' will be deprecated in the next release. Please consider using 'Tab2vec' instead\n", + " DeprecationWarning,\n" + ] + }, { "data": { "text/plain": [ - "{'11th': array([ 0.17823647, 0.04097574, -0.16298912, -0.11065536, 0.1432162 ,\n", - " 0.10531982, -0.34251764, 0.40085673, 0.09578304, 0.15393786,\n", - " 0.26928946, -0.05603978, -0.2609236 , 0.0091235 , 0.07494199,\n", - " 0.02190116], dtype=float32),\n", - " 'HS-grad': array([ 0.120097 , -0.13213032, -0.0592633 , 0.04583196, -0.04858546,\n", - " -0.39242733, -0.43368143, 0.00434827, 0.04477202, 0.07125217,\n", - " -0.15088314, -0.2939101 , 0.31975606, -0.341947 , 0.22773097,\n", - " 0.28342503], dtype=float32),\n", - " 'Assoc-acdm': array([ 0.14957719, -0.18953936, -0.22840326, 0.45375347, 0.26669678,\n", - " 0.05090672, 0.46574584, 0.2774832 , -0.12203862, 0.13699052,\n", - " -0.27128282, -0.34413835, 0.29697102, 0.12395442, 0.14231798,\n", - " -0.10790487], dtype=float32),\n", - " 'Some-college': array([-0.2126067 , 0.04664122, -0.15191978, -0.10957965, -0.12881616,\n", - " -0.04466751, 0.25502843, 0.32889867, 0.0168101 , 0.20086999,\n", - " -0.21912436, -0.00544369, 0.03351 , -0.17859232, 0.1382413 ,\n", - " 0.26502082], dtype=float32),\n", - " '10th': array([-0.3121446 , 0.19805874, -0.03366002, 0.1288065 , 0.26396075,\n", - " -0.05587888, 0.22792356, -0.06681106, 0.12476017, 0.37026265,\n", - " 0.03204104, -0.09612755, 0.0324997 , -0.08246089, 0.04117873,\n", - " 0.1853117 ], dtype=float32),\n", - " 'Prof-school': array([-0.4429325 , -0.12834997, 0.3658504 , 0.48140833, 0.11574885,\n", - " -0.192547 , 0.1586941 , -0.2919336 , 0.1567621 , 0.29656097,\n", - " 0.18974394, 0.06253866, 0.16234514, -0.08963383, -0.08024175,\n", - " 0.54286146], dtype=float32),\n", - " '7th-8th': array([ 0.54942334, 0.37394103, -0.03598195, -0.05772773, -0.28254417,\n", - " 0.54470855, -0.6513119 , -0.13811558, -0.11478714, 0.06010893,\n", - " -0.2462508 , 0.1755247 , 0.10117105, 0.36358032, -0.09656113,\n", - " 0.34954002], dtype=float32),\n", - " 'Bachelors': array([ 0.06564163, -0.23048915, -0.12470629, -0.02602417, 0.35001647,\n", - " -0.18802756, 0.10905975, -0.33273023, 0.01738172, 0.2478116 ,\n", - " 0.00981276, -0.18224423, 0.0950555 , 0.17849174, 0.17942917,\n", - " 0.31124604], dtype=float32),\n", - " 'Masters': array([ 0.13041618, -0.07283561, -0.34077218, 0.05142086, 0.08315329,\n", - " -0.12212724, 0.31239262, -0.20927685, -0.24285726, 0.06567737,\n", - " 0.03671836, -0.03405587, 0.01641322, 0.17043172, -0.38756114,\n", - " 0.30868122], dtype=float32),\n", - " 'Doctorate': array([-0.10755017, -0.03946237, -0.5153946 , 0.23642367, -0.4680825 ,\n", - " 0.2587171 , -0.1300325 , -0.05143512, -0.20121185, -0.02474 ,\n", - " -0.09320115, -0.07455952, 0.10833438, -0.02096028, -0.12492044,\n", - " 0.00582709], dtype=float32),\n", - " '5th-6th': array([-0.12893526, 0.27144003, 0.37272307, 0.3963532 , 0.34640747,\n", - " -0.33437288, 0.0193824 , -0.01519158, -0.42908698, 0.05110272,\n", - " 0.01151075, 0.15922028, -0.17880926, -0.36683136, -0.40467307,\n", - " -0.12017028], dtype=float32),\n", - " 'Assoc-voc': array([ 0.02241084, -0.07670853, -0.22828907, -0.12371975, -0.07486907,\n", - " -0.29233935, 0.31587106, 0.2165355 , 0.20171323, -0.15870345,\n", - " -0.1275358 , -0.21006238, -0.03274518, -0.14725143, -0.213672 ,\n", - " 0.30866137], dtype=float32),\n", - " '9th': array([ 0.1470835 , -0.0528347 , 0.24995384, -0.21315503, -0.24470845,\n", - " 0.819329 , 0.04469828, 0.09546001, 0.24664721, 0.3054443 ,\n", - " 0.4566717 , 0.14872263, 0.0116579 , 0.2515947 , 0.2023506 ,\n", - " -0.3379088 ], dtype=float32),\n", - " '12th': array([-0.01843497, 0.21602574, -0.35730916, -0.16129005, 0.34858495,\n", - " 0.07911005, -0.09155226, 0.25502652, -0.20713754, -0.2009355 ,\n", - " -0.18680803, 0.05695441, -0.20793928, -0.01325957, -0.28487244,\n", - " 0.26250076], dtype=float32),\n", - " '1st-4th': array([-0.5274408 , -0.17692605, -0.32478535, -0.15695599, 0.03235544,\n", - " -0.37266013, 0.35468644, 0.16074362, -0.36835802, 0.37510112,\n", - " 0.0420665 , -0.19070098, 0.33601463, -0.4323496 , -0.19171081,\n", - " -0.27081746], dtype=float32),\n", - " 'Preschool': array([ 0.07924446, 0.11405066, -0.17461444, -0.11104126, 0.45389435,\n", - " -0.06884138, -0.07859107, 0.30992216, -0.09668542, -0.03197801,\n", - " 0.25111035, 0.5209666 , 0.61060447, 0.03642207, 0.05149668,\n", - " 0.14839056], dtype=float32)}" + "{'11th': array([-0.3475832 , 0.34912273, -0.11974874, 0.14691196, -0.22545682,\n", + " -0.3613695 , -0.00136127, -0.0563265 , 0.3466888 , 0.11706785,\n", + " -0.01166581, -0.01369573, -0.17875178, 0.18713965, 0.2914308 ,\n", + " -0.198182 ], dtype=float32),\n", + " 'HS-grad': array([ 0.09942148, -0.33260158, 0.2164713 , -0.2940495 , 0.22636804,\n", + " 0.12042803, -0.07338171, 0.17187971, -0.12905738, 0.3129245 ,\n", + " -0.31488863, -0.17345233, 0.32477817, 0.00439972, 0.39258945,\n", + " -0.14481816], dtype=float32),\n", + " 'Assoc-acdm': array([-0.00751864, -0.1771137 , 0.06895561, -0.21083945, 0.23953192,\n", + " -0.6551445 , 0.01284237, -0.0050387 , -0.07738334, 0.00540992,\n", + " 0.0681937 , 0.05531053, -0.4259041 , -0.1871334 , -0.04381247,\n", + " 0.32671115], dtype=float32),\n", + " 'Some-college': array([ 0.01929094, 0.10994322, 0.36765632, -0.23809849, 0.10644584,\n", + " -0.19297272, -0.39444843, 0.32810718, -0.05060181, 0.4375799 ,\n", + " 0.34009618, -0.30499312, 0.06079052, -0.36158556, 0.16431686,\n", + " -0.02064201], dtype=float32),\n", + " '10th': array([ 0.32986915, 0.20145807, -0.46201912, 0.15131666, 0.39709982,\n", + " 0.69238126, 0.20381889, 0.10686771, -0.00311412, 0.40032774,\n", + " -0.25356117, 0.05119215, -0.5510974 , -0.3487673 , -0.05308707,\n", + " -0.15400933], dtype=float32),\n", + " 'Prof-school': array([-0.08333081, -0.1471433 , 0.0884981 , 0.7094311 , -0.22927387,\n", + " -0.07990997, 0.09308612, 0.13682584, 0.31950092, 0.0993206 ,\n", + " 0.31872186, -0.05731025, -0.02362061, 0.2931348 , 0.22886205,\n", + " 0.07881528], dtype=float32),\n", + " '7th-8th': array([-0.17651281, 0.14389877, -0.51749426, -0.36477914, -0.24834834,\n", + " -0.35161367, -0.38828874, -0.36244267, 0.16288954, -0.2656618 ,\n", + " -0.42065242, -0.16790003, -0.04955713, -0.4936896 , -0.07479241,\n", + " -0.15467522], dtype=float32),\n", + " 'Bachelors': array([-1.7109926e-01, -1.5259677e-01, 2.1966804e-02, 2.4115700e-01,\n", + " -6.3872856e-01, -6.5369323e-02, -2.5777605e-01, 1.3853328e-01,\n", + " -1.3078525e-04, 3.5386881e-01, 3.6984026e-02, 1.3007362e-01,\n", + " -3.4332672e-01, 1.4918861e-01, -2.7776187e-02, 1.8584514e-02],\n", + " dtype=float32),\n", + " 'Masters': array([-0.0607332 , 0.03891989, -0.12000098, -0.26994392, 0.1360479 ,\n", + " -0.12739065, -0.42029268, 0.4281459 , -0.2187682 , -0.10016173,\n", + " 0.21315324, 0.06292748, -0.17620797, 0.06142575, -0.16202934,\n", + " 0.07813183], dtype=float32),\n", + " 'Doctorate': array([ 0.07975567, -0.2995707 , -0.1297475 , -0.23506498, -0.07601811,\n", + " 0.21119696, -0.4014182 , 0.3409825 , 0.00557449, 0.17662002,\n", + " 0.01124496, 0.01987186, 0.00463357, 0.05345817, 0.28748044,\n", + " -0.24112043], dtype=float32),\n", + " '5th-6th': array([ 0.5512265 , -0.1908678 , -0.27131537, 0.36982986, 0.26176104,\n", + " 0.36645773, -0.23311335, -0.12837252, 0.24260557, -0.01326179,\n", + " 0.14636081, -0.0393713 , 0.12896451, -0.14971113, 0.01964791,\n", + " -0.04153565], dtype=float32),\n", + " 'Assoc-voc': array([ 0.00294167, -0.39827642, -0.02495229, 0.13957082, -0.13182898,\n", + " -0.27178332, 0.14512709, -0.29180354, -0.39801288, -0.15011302,\n", + " -0.19905967, -0.19827461, 0.1912367 , 0.1386391 , 0.13930447,\n", + " 0.05284905], dtype=float32),\n", + " '9th': array([ 0.10768763, -0.06806335, -0.18458003, 0.07836349, 0.3678258 ,\n", + " -0.03671409, -0.02125892, -0.22644126, 0.24890126, -0.01134706,\n", + " -0.35545322, -0.26837015, -0.22845276, -0.00784048, -0.01379843,\n", + " 0.07515417], dtype=float32),\n", + " '12th': array([ 0.3593276 , 0.4534212 , -0.17996144, 0.288639 , 0.03528969,\n", + " 0.01434682, 0.33964154, -0.19378136, -0.09871213, 0.073057 ,\n", + " -0.09627059, -0.1055373 , -0.24785268, 0.4939406 , 0.11959701,\n", + " -0.10218817], dtype=float32),\n", + " '1st-4th': array([ 0.1668331 , 0.12820388, -0.19806872, -0.07793977, 0.10481353,\n", + " -0.2746754 , 0.33344626, 0.09796116, 0.4184359 , -0.06698637,\n", + " -0.49304193, -0.44370967, -0.07711301, -0.24175553, 0.42256072,\n", + " 0.3274595 ], dtype=float32),\n", + " 'Preschool': array([ 0.5568573 , -0.02487507, -0.21234682, -0.19250521, -0.26240364,\n", + " 0.11080477, 0.41791028, -0.10821233, -0.04813243, -0.19189784,\n", + " -0.03009004, 0.28244784, 0.44653463, 0.2691065 , 0.4888137 ,\n", + " -0.01453342], dtype=float32)}" ] }, "execution_count": 18, diff --git a/examples/05_Regression_with_Images_and_Text.ipynb b/examples/05_Regression_with_Images_and_Text.ipynb index b3dddd32c68284b6d0f575aaca742feca98ac088..e6664b5000f217e2fbdbe7ca1c3acb4f0c66ef45 100644 --- a/examples/05_Regression_with_Images_and_Text.ipynb +++ b/examples/05_Regression_with_Images_and_Text.ipynb @@ -1069,7 +1069,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 4%|▍ | 40/1001 [00:00<00:02, 392.70it/s]" + " 4%|▍ | 41/1001 [00:00<00:02, 402.45it/s]" ] }, { @@ -1083,7 +1083,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1001/1001 [00:02<00:00, 382.63it/s]\n" + "100%|██████████| 1001/1001 [00:02<00:00, 411.97it/s]\n" ] }, { @@ -1173,15 +1173,15 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 25/25 [02:05<00:00, 5.03s/it, loss=107]\n", - "valid: 100%|██████████| 7/7 [00:15<00:00, 2.22s/it, loss=129]\n" + "epoch 1: 100%|██████████| 25/25 [02:13<00:00, 5.35s/it, loss=115]\n", + "valid: 100%|██████████| 7/7 [00:15<00:00, 2.20s/it, loss=108] \n" ] } ], @@ -1201,7 +1201,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -1238,7 +1238,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -1260,7 +1260,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": { "scrolled": false }, @@ -1273,19 +1273,21 @@ " (wide_linear): Embedding(357, 1, padding_idx=0)\n", " )\n", " (deeptabular): TabMlp(\n", - " (embed_layers): ModuleDict(\n", - " (emb_layer_accommodates_catg): Embedding(4, 16, padding_idx=0)\n", - " (emb_layer_bathrooms_catg): Embedding(4, 16, padding_idx=0)\n", - " (emb_layer_bedrooms_catg): Embedding(5, 16, padding_idx=0)\n", - " (emb_layer_beds_catg): Embedding(5, 16, padding_idx=0)\n", - " (emb_layer_cancellation_policy): Embedding(6, 16, padding_idx=0)\n", - " (emb_layer_guests_included_catg): Embedding(4, 16, padding_idx=0)\n", - " (emb_layer_host_listings_count_catg): Embedding(5, 16, padding_idx=0)\n", - " (emb_layer_minimum_nights_catg): Embedding(4, 16, padding_idx=0)\n", - " (emb_layer_neighbourhood_cleansed): Embedding(33, 64, padding_idx=0)\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (embed_layers): ModuleDict(\n", + " (emb_layer_accommodates_catg): Embedding(4, 16, padding_idx=0)\n", + " (emb_layer_bathrooms_catg): Embedding(4, 16, padding_idx=0)\n", + " (emb_layer_bedrooms_catg): Embedding(5, 16, padding_idx=0)\n", + " (emb_layer_beds_catg): Embedding(5, 16, padding_idx=0)\n", + " (emb_layer_cancellation_policy): Embedding(6, 16, padding_idx=0)\n", + " (emb_layer_guests_included_catg): Embedding(4, 16, padding_idx=0)\n", + " (emb_layer_host_listings_count_catg): Embedding(5, 16, padding_idx=0)\n", + " (emb_layer_minimum_nights_catg): Embedding(4, 16, padding_idx=0)\n", + " (emb_layer_neighbourhood_cleansed): Embedding(33, 64, padding_idx=0)\n", + " )\n", + " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", + " (cont_norm): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (tab_mlp): MLP(\n", " (mlp): Sequential(\n", " (dense_layer_0): Sequential(\n", @@ -1423,7 +1425,7 @@ ")" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -1443,7 +1445,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -1457,7 +1459,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -1470,7 +1472,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -1483,7 +1485,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -1509,7 +1511,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -1535,15 +1537,15 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 25/25 [02:09<00:00, 5.18s/it, loss=106]\n", - "valid: 100%|██████████| 7/7 [00:16<00:00, 2.31s/it, loss=95.5]" + "epoch 1: 100%|██████████| 25/25 [02:08<00:00, 5.12s/it, loss=108]\n", + "valid: 100%|██████████| 7/7 [00:15<00:00, 2.19s/it, loss=92.9]" ] }, { @@ -1575,7 +1577,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1603,7 +1605,7 @@ " 'lr_deephead_0': [0.001, 0.001]}" ] }, - "execution_count": 23, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/06_FineTune_and_WarmUp_Model_Components.ipynb b/examples/06_FineTune_and_WarmUp_Model_Components.ipynb index 95624f67163890f4f0382db435063375f6fc8046..1395e3889d910a166a4f37d7056b16169859cc8b 100644 --- a/examples/06_FineTune_and_WarmUp_Model_Components.ipynb +++ b/examples/06_FineTune_and_WarmUp_Model_Components.ipynb @@ -306,10 +306,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 153/153 [00:04<00:00, 36.06it/s, loss=0.529, metrics={'acc': 0.7448}]\n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 68.26it/s, loss=0.389, metrics={'acc': 0.8176}]\n", - "epoch 2: 100%|██████████| 153/153 [00:03<00:00, 39.18it/s, loss=0.401, metrics={'acc': 0.8122}]\n", - "valid: 100%|██████████| 39/39 [00:00<00:00, 116.68it/s, loss=0.368, metrics={'acc': 0.8272}]\n" + "epoch 1: 100%|██████████| 153/153 [00:03<00:00, 43.48it/s, loss=0.565, metrics={'acc': 0.7249}]\n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 62.94it/s, loss=0.387, metrics={'acc': 0.8207}]\n", + "epoch 2: 100%|██████████| 153/153 [00:04<00:00, 30.88it/s, loss=0.389, metrics={'acc': 0.8195}]\n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 92.51it/s, loss=0.372, metrics={'acc': 0.8261}] \n" ] } ], @@ -387,7 +387,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 3%|▎ | 5/191 [00:00<00:03, 47.72it/s, loss=0.794, metrics={'acc': 0.5348}]" + "epoch 1: 4%|▎ | 7/191 [00:00<00:02, 64.63it/s, loss=1.17, metrics={'acc': 0.418}] " ] }, { @@ -401,9 +401,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 191/191 [00:02<00:00, 67.54it/s, loss=0.504, metrics={'acc': 0.7554}]\n", - "epoch 2: 100%|██████████| 191/191 [00:02<00:00, 70.24it/s, loss=0.386, metrics={'acc': 0.79}] \n", - "epoch 1: 4%|▎ | 7/191 [00:00<00:03, 60.96it/s, loss=0.39, metrics={'acc': 0.7909}] " + "epoch 1: 100%|██████████| 191/191 [00:02<00:00, 68.57it/s, loss=0.584, metrics={'acc': 0.7155}]\n", + "epoch 2: 100%|██████████| 191/191 [00:03<00:00, 62.76it/s, loss=0.39, metrics={'acc': 0.7697}] \n", + "epoch 1: 3%|▎ | 6/191 [00:00<00:03, 56.94it/s, loss=0.403, metrics={'acc': 0.7705}]" ] }, { @@ -417,9 +417,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 191/191 [00:03<00:00, 62.41it/s, loss=0.369, metrics={'acc': 0.8028}]\n", - "epoch 2: 100%|██████████| 191/191 [00:03<00:00, 59.52it/s, loss=0.352, metrics={'acc': 0.8107}]\n", - "epoch 1: 3%|▎ | 5/191 [00:00<00:04, 43.10it/s, loss=0.363, metrics={'acc': 0.8418}]" + "epoch 1: 100%|██████████| 191/191 [00:03<00:00, 58.09it/s, loss=0.369, metrics={'acc': 0.7887}]\n", + "epoch 2: 100%|██████████| 191/191 [00:03<00:00, 50.37it/s, loss=0.353, metrics={'acc': 0.8003}]\n", + "epoch 1: 2%|▏ | 4/191 [00:00<00:05, 36.39it/s, loss=0.399, metrics={'acc': 0.8298}]" ] }, { @@ -433,8 +433,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 191/191 [00:04<00:00, 39.91it/s, loss=0.352, metrics={'acc': 0.8378}]\n", - "epoch 2: 100%|██████████| 191/191 [00:04<00:00, 43.80it/s, loss=0.344, metrics={'acc': 0.8419}]\n" + "epoch 1: 100%|██████████| 191/191 [00:05<00:00, 33.73it/s, loss=0.355, metrics={'acc': 0.8377}]\n", + "epoch 2: 100%|██████████| 191/191 [00:05<00:00, 36.01it/s, loss=0.347, metrics={'acc': 0.8396}]\n" ] } ], @@ -488,7 +488,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 3%|▎ | 6/172 [00:00<00:02, 58.53it/s, loss=0.988, metrics={'acc': 0.4435}]" + "epoch 1: 3%|▎ | 6/172 [00:00<00:03, 52.52it/s, loss=0.628, metrics={'acc': 0.6977}]" ] }, { @@ -502,9 +502,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 172/172 [00:02<00:00, 73.06it/s, loss=0.54, metrics={'acc': 0.7276}] \n", - "epoch 2: 100%|██████████| 172/172 [00:02<00:00, 75.57it/s, loss=0.389, metrics={'acc': 0.7736}]\n", - "epoch 1: 3%|▎ | 6/172 [00:00<00:02, 55.48it/s, loss=0.582, metrics={'acc': 0.7728}]" + "epoch 1: 100%|██████████| 172/172 [00:02<00:00, 68.16it/s, loss=0.475, metrics={'acc': 0.7799}]\n", + "epoch 2: 100%|██████████| 172/172 [00:02<00:00, 68.59it/s, loss=0.387, metrics={'acc': 0.8021}]\n", + "epoch 1: 2%|▏ | 4/172 [00:00<00:04, 34.15it/s, loss=0.62, metrics={'acc': 0.8009}] " ] }, { @@ -518,9 +518,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 172/172 [00:02<00:00, 58.52it/s, loss=0.392, metrics={'acc': 0.7881}]\n", - "epoch 2: 100%|██████████| 172/172 [00:02<00:00, 58.26it/s, loss=0.353, metrics={'acc': 0.8}] \n", - "epoch 1: 2%|▏ | 4/172 [00:00<00:04, 38.87it/s, loss=0.337, metrics={'acc': 0.8589}]" + "epoch 1: 100%|██████████| 172/172 [00:04<00:00, 38.78it/s, loss=0.392, metrics={'acc': 0.8075}]\n", + "epoch 2: 100%|██████████| 172/172 [00:03<00:00, 49.88it/s, loss=0.354, metrics={'acc': 0.8145}]\n", + "epoch 1: 2%|▏ | 4/172 [00:00<00:04, 38.71it/s, loss=0.412, metrics={'acc': 0.8326}]" ] }, { @@ -534,10 +534,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 172/172 [00:04<00:00, 42.81it/s, loss=0.355, metrics={'acc': 0.8366}]\n", - "valid: 100%|██████████| 20/20 [00:00<00:00, 89.21it/s, loss=0.35, metrics={'acc': 0.8356}] \n", - "epoch 2: 100%|██████████| 172/172 [00:04<00:00, 41.35it/s, loss=0.346, metrics={'acc': 0.8381}]\n", - "valid: 100%|██████████| 20/20 [00:00<00:00, 87.63it/s, loss=0.349, metrics={'acc': 0.8373}]\n" + "epoch 1: 100%|██████████| 172/172 [00:04<00:00, 38.84it/s, loss=0.354, metrics={'acc': 0.8372}]\n", + "valid: 100%|██████████| 20/20 [00:00<00:00, 81.14it/s, loss=0.348, metrics={'acc': 0.8399}]\n", + "epoch 2: 100%|██████████| 172/172 [00:04<00:00, 37.92it/s, loss=0.345, metrics={'acc': 0.8397}]\n", + "valid: 100%|██████████| 20/20 [00:00<00:00, 46.33it/s, loss=0.347, metrics={'acc': 0.8409}]\n" ] } ], @@ -629,15 +629,17 @@ " )\n", " (deeptabular): Sequential(\n", " (0): TabResnet(\n", - " (embed_layers): ModuleDict(\n", - " (emb_layer_education): Embedding(17, 16, padding_idx=0)\n", - " (emb_layer_native_country): Embedding(43, 16, padding_idx=0)\n", - " (emb_layer_occupation): Embedding(16, 16, padding_idx=0)\n", - " (emb_layer_relationship): Embedding(7, 8, padding_idx=0)\n", - " (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (embed_layers): ModuleDict(\n", + " (emb_layer_education): Embedding(17, 16, padding_idx=0)\n", + " (emb_layer_native_country): Embedding(43, 16, padding_idx=0)\n", + " (emb_layer_occupation): Embedding(16, 16, padding_idx=0)\n", + " (emb_layer_relationship): Embedding(7, 8, padding_idx=0)\n", + " (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n", + " )\n", + " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", + " (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (tab_resnet_blks): DenseResnet(\n", " (dense_resnet): Sequential(\n", " (lin1): Linear(in_features=74, out_features=128, bias=True)\n", @@ -708,10 +710,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 172/172 [00:05<00:00, 29.00it/s, loss=0.453, metrics={'acc': 0.7787}]\n", - "valid: 100%|██████████| 20/20 [00:00<00:00, 90.03it/s, loss=0.363, metrics={'acc': 0.8282}]\n", - "epoch 2: 100%|██████████| 172/172 [00:05<00:00, 32.24it/s, loss=0.371, metrics={'acc': 0.8262}]\n", - "valid: 100%|██████████| 20/20 [00:00<00:00, 88.22it/s, loss=0.351, metrics={'acc': 0.8356}]\n" + "epoch 1: 100%|██████████| 172/172 [00:07<00:00, 23.55it/s, loss=0.411, metrics={'acc': 0.8033}]\n", + "valid: 100%|██████████| 20/20 [00:00<00:00, 71.22it/s, loss=0.364, metrics={'acc': 0.8287}]\n", + "epoch 2: 100%|██████████| 172/172 [00:06<00:00, 25.12it/s, loss=0.369, metrics={'acc': 0.827}] \n", + "valid: 100%|██████████| 20/20 [00:00<00:00, 78.16it/s, loss=0.355, metrics={'acc': 0.8342}]\n" ] } ], @@ -788,7 +790,7 @@ "outputs": [], "source": [ "tab_deep_layers = list(\n", - " list(list(list(model_3.deeptabular.children())[0].children())[3].children())[\n", + " list(list(list(model_3.deeptabular.children())[0].children())[1].children())[\n", " 0\n", " ].children()\n", ")[::-1][:2]" @@ -865,14 +867,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 5%|▍ | 8/172 [00:00<00:02, 68.51it/s, loss=0.767, metrics={'acc': 0.5605}]" + "epoch 1: 5%|▍ | 8/172 [00:00<00:02, 71.14it/s, loss=0.719, metrics={'acc': 0.6278}]" ] }, { @@ -886,9 +888,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 172/172 [00:02<00:00, 75.72it/s, loss=0.489, metrics={'acc': 0.7523}]\n", - "epoch 2: 100%|██████████| 172/172 [00:02<00:00, 64.95it/s, loss=0.383, metrics={'acc': 0.7876}]\n", - "epoch 1: 2%|▏ | 3/172 [00:00<00:07, 22.26it/s, loss=0.402, metrics={'acc': 0.788}] " + "epoch 1: 100%|██████████| 172/172 [00:03<00:00, 56.88it/s, loss=0.496, metrics={'acc': 0.7596}]\n", + "epoch 2: 100%|██████████| 172/172 [00:02<00:00, 68.06it/s, loss=0.386, metrics={'acc': 0.7917}]\n", + "epoch 1: 2%|▏ | 4/172 [00:00<00:04, 38.40it/s, loss=0.435, metrics={'acc': 0.7915}]" ] }, { @@ -902,8 +904,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 172/172 [00:08<00:00, 20.71it/s, loss=0.385, metrics={'acc': 0.7986}]\n", - "epoch 1: 0%| | 0/172 [00:00" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model_5.load_state_dict(torch.load(\"models_dir/model_5.pt\"))" ] @@ -1047,7 +1070,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -1056,9 +1079,46 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 36, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 3%|▎ | 6/172 [00:00<00:03, 51.73it/s, loss=0.371, metrics={'acc': 0.8247}]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training deeptabular for 2 epochs\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|██████████| 172/172 [00:03<00:00, 48.47it/s, loss=0.367, metrics={'acc': 0.8287}]\n", + "epoch 2: 100%|██████████| 172/172 [00:03<00:00, 51.73it/s, loss=0.352, metrics={'acc': 0.833}] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fine-tuning finished\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "trainer_6.fit(\n", " X_wide=X_wide, \n", @@ -1076,7 +1136,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ diff --git a/examples/07_Custom_Components.ipynb b/examples/07_Custom_Components.ipynb index 1f6387d59cd1c83edaf12ff4267f25ae715190d3..eb71c4cbd8b2371186afa91dce7db8e71824bc1e 100644 --- a/examples/07_Custom_Components.ipynb +++ b/examples/07_Custom_Components.ipynb @@ -1011,7 +1011,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 4%|▍ | 42/1001 [00:00<00:02, 419.42it/s]" + " 3%|▎ | 29/1001 [00:00<00:03, 288.35it/s]" ] }, { @@ -1025,7 +1025,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1001/1001 [00:02<00:00, 408.24it/s]\n" + "100%|██████████| 1001/1001 [00:03<00:00, 307.23it/s]\n" ] }, { @@ -1194,8 +1194,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 25/25 [02:03<00:00, 4.94s/it, loss=111]\n", - "valid: 100%|██████████| 7/7 [00:15<00:00, 2.17s/it, loss=94.8]\n" + "epoch 1: 100%|██████████| 25/25 [02:21<00:00, 5.65s/it, loss=120]\n", + "valid: 100%|██████████| 7/7 [00:15<00:00, 2.21s/it, loss=217]\n" ] } ], @@ -1735,16 +1735,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 611/611 [00:06<00:00, 90.27it/s, loss=0.455, metrics={'acc': 0.7866}]\n", - "valid: 100%|██████████| 153/153 [00:00<00:00, 155.16it/s, loss=0.367, metrics={'acc': 0.8318}]\n", - "epoch 2: 100%|██████████| 611/611 [00:06<00:00, 89.76it/s, loss=0.379, metrics={'acc': 0.8233}]\n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 151.92it/s, loss=0.354, metrics={'acc': 0.8362}]\n", - "epoch 3: 100%|██████████| 611/611 [00:06<00:00, 89.63it/s, loss=0.365, metrics={'acc': 0.8296}]\n", - "valid: 100%|██████████| 153/153 [00:00<00:00, 154.56it/s, loss=0.35, metrics={'acc': 0.8383}] \n", - "epoch 4: 100%|██████████| 611/611 [00:06<00:00, 87.99it/s, loss=0.357, metrics={'acc': 0.8333}]\n", - "valid: 100%|██████████| 153/153 [00:00<00:00, 155.60it/s, loss=0.348, metrics={'acc': 0.8406}]\n", - "epoch 5: 100%|██████████| 611/611 [00:07<00:00, 87.18it/s, loss=0.354, metrics={'acc': 0.835}] \n", - "valid: 100%|██████████| 153/153 [00:01<00:00, 149.32it/s, loss=0.346, metrics={'acc': 0.8402}]\n" + "epoch 1: 100%|██████████| 611/611 [00:07<00:00, 83.88it/s, loss=0.402, metrics={'acc': 0.8093}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 153.65it/s, loss=0.359, metrics={'acc': 0.8365}]\n", + "epoch 2: 100%|██████████| 611/611 [00:06<00:00, 89.07it/s, loss=0.365, metrics={'acc': 0.8276}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 155.09it/s, loss=0.354, metrics={'acc': 0.8367}]\n", + "epoch 3: 100%|██████████| 611/611 [00:06<00:00, 89.26it/s, loss=0.357, metrics={'acc': 0.8338}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 154.81it/s, loss=0.351, metrics={'acc': 0.8371}]\n", + "epoch 4: 100%|██████████| 611/611 [00:06<00:00, 89.63it/s, loss=0.354, metrics={'acc': 0.8335}]\n", + "valid: 100%|██████████| 153/153 [00:01<00:00, 151.97it/s, loss=0.348, metrics={'acc': 0.8401}]\n", + "epoch 5: 100%|██████████| 611/611 [00:09<00:00, 66.64it/s, loss=0.351, metrics={'acc': 0.8349}]\n", + "valid: 100%|██████████| 153/153 [00:01<00:00, 127.27it/s, loss=0.347, metrics={'acc': 0.8406}]\n" ] } ], diff --git a/examples/08_save_and_load_model_and_artifacts.ipynb b/examples/08_save_and_load_model_and_artifacts.ipynb index 384f896e414219973fd181ad87a1107fe7033d3e..2b4605e494782e74f0a53685421540616a85ab25 100644 --- a/examples/08_save_and_load_model_and_artifacts.ipynb +++ b/examples/08_save_and_load_model_and_artifacts.ipynb @@ -488,31 +488,33 @@ "text/plain": [ "WideDeep(\n", " (wide): Wide(\n", - " (wide_linear): Embedding(773, 1, padding_idx=0)\n", + " (wide_linear): Embedding(779, 1, padding_idx=0)\n", " )\n", " (deeptabular): Sequential(\n", " (0): TabMlp(\n", - " (embed_layers): ModuleDict(\n", - " (emb_layer_age): Embedding(75, 18, padding_idx=0)\n", - " (emb_layer_capital_gain): Embedding(121, 23, padding_idx=0)\n", - " (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0)\n", - " (emb_layer_education): Embedding(17, 8, padding_idx=0)\n", - " (emb_layer_educational_num): Embedding(17, 8, padding_idx=0)\n", - " (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n", - " (emb_layer_hours_per_week): Embedding(96, 20, padding_idx=0)\n", - " (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)\n", - " (emb_layer_native_country): Embedding(43, 13, padding_idx=0)\n", - " (emb_layer_occupation): Embedding(16, 7, padding_idx=0)\n", - " (emb_layer_race): Embedding(6, 4, padding_idx=0)\n", - " (emb_layer_relationship): Embedding(7, 4, padding_idx=0)\n", - " (emb_layer_workclass): Embedding(10, 5, padding_idx=0)\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (embed_layers): ModuleDict(\n", + " (emb_layer_age): Embedding(75, 18, padding_idx=0)\n", + " (emb_layer_capital_gain): Embedding(119, 23, padding_idx=0)\n", + " (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0)\n", + " (emb_layer_education): Embedding(17, 8, padding_idx=0)\n", + " (emb_layer_educational_num): Embedding(17, 8, padding_idx=0)\n", + " (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n", + " (emb_layer_hours_per_week): Embedding(97, 21, padding_idx=0)\n", + " (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)\n", + " (emb_layer_native_country): Embedding(43, 13, padding_idx=0)\n", + " (emb_layer_occupation): Embedding(16, 7, padding_idx=0)\n", + " (emb_layer_race): Embedding(6, 4, padding_idx=0)\n", + " (emb_layer_relationship): Embedding(7, 4, padding_idx=0)\n", + " (emb_layer_workclass): Embedding(10, 5, padding_idx=0)\n", + " )\n", + " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", " )\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", " (tab_mlp): MLP(\n", " (mlp): Sequential(\n", " (dense_layer_0): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", - " (1): Linear(in_features=138, out_features=200, bias=True)\n", + " (1): Linear(in_features=139, out_features=200, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " (dense_layer_1): Sequential(\n", @@ -546,9 +548,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 153/153 [00:04<00:00, 31.06it/s, loss=0.479, metrics={'acc': 0.7839}]\n", - "valid: 100%|██████████| 20/20 [00:00<00:00, 48.70it/s, loss=0.348, metrics={'acc': 0.8444}]\n", - "epoch 2: 3%|▎ | 4/153 [00:00<00:04, 32.06it/s, loss=0.365, metrics={'acc': 0.8385}]" + "epoch 1: 100%|██████████| 153/153 [00:04<00:00, 31.17it/s, loss=0.444, metrics={'acc': 0.7952}]\n", + "valid: 100%|██████████| 20/20 [00:00<00:00, 46.39it/s, loss=0.357, metrics={'acc': 0.838}] \n", + "epoch 2: 2%|▏ | 3/153 [00:00<00:05, 27.83it/s, loss=0.38, metrics={'acc': 0.8305}] " ] }, { @@ -556,15 +558,15 @@ "output_type": "stream", "text": [ "\n", - "Epoch 00001: val_loss improved from inf to 0.34800, saving model to tmp_dir/adult_tabmlp_model_1.p\n" + "Epoch 00001: val_loss improved from inf to 0.35746, saving model to tmp_dir/adult_tabmlp_model_1.p\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "epoch 2: 100%|██████████| 153/153 [00:04<00:00, 32.33it/s, loss=0.354, metrics={'acc': 0.8379}]\n", - "valid: 100%|██████████| 20/20 [00:00<00:00, 92.91it/s, loss=0.322, metrics={'acc': 0.8511}] " + "epoch 2: 100%|██████████| 153/153 [00:04<00:00, 31.57it/s, loss=0.352, metrics={'acc': 0.8395}]\n", + "valid: 100%|██████████| 20/20 [00:00<00:00, 88.95it/s, loss=0.325, metrics={'acc': 0.8557}]" ] }, { @@ -572,7 +574,7 @@ "output_type": "stream", "text": [ "\n", - "Epoch 00002: val_loss improved from 0.34800 to 0.32204, saving model to tmp_dir/adult_tabmlp_model_2.p\n", + "Epoch 00002: val_loss improved from 0.35746 to 0.32545, saving model to tmp_dir/adult_tabmlp_model_2.p\n", "Model weights restored to best epoch: 2\n" ] }, @@ -844,120 +846,120 @@ " \n", " \n", " \n", - " 32823\n", - " 33\n", + " 3428\n", + " 40\n", " Private\n", - " 201988\n", - " Masters\n", - " 14\n", + " 144778\n", + " HS-grad\n", + " 9\n", " Married-civ-spouse\n", - " Prof-specialty\n", + " Exec-managerial\n", " Husband\n", " White\n", " Male\n", " 0\n", " 0\n", - " 45\n", + " 50\n", " United-States\n", - " 0\n", + " 1\n", " \n", " \n", - " 40713\n", - " 31\n", - " Private\n", - " 231826\n", + " 8234\n", + " 38\n", + " ?\n", + " 54953\n", " HS-grad\n", " 9\n", - " Married-civ-spouse\n", - " Other-service\n", - " Husband\n", + " Divorced\n", + " ?\n", + " Not-in-family\n", " White\n", " Male\n", " 0\n", " 0\n", - " 52\n", - " Mexico\n", + " 30\n", + " United-States\n", " 0\n", " \n", " \n", - " 16020\n", - " 38\n", - " Private\n", - " 24126\n", - " Some-college\n", - " 10\n", - " Divorced\n", - " Exec-managerial\n", - " Not-in-family\n", + " 1129\n", + " 28\n", + " Local-gov\n", + " 134771\n", + " Bachelors\n", + " 13\n", + " Never-married\n", + " Prof-specialty\n", + " Own-child\n", " White\n", " Female\n", " 0\n", " 0\n", - " 40\n", + " 55\n", " United-States\n", " 0\n", " \n", " \n", - " 32766\n", - " 38\n", - " State-gov\n", - " 312528\n", - " Bachelors\n", - " 13\n", - " Married-civ-spouse\n", - " Exec-managerial\n", - " Husband\n", + " 11866\n", + " 47\n", + " Private\n", + " 189143\n", + " HS-grad\n", + " 9\n", + " Divorced\n", + " Farming-fishing\n", + " Not-in-family\n", " White\n", " Male\n", " 0\n", " 0\n", - " 37\n", + " 40\n", " United-States\n", " 0\n", " \n", " \n", - " 9713\n", - " 40\n", - " Self-emp-not-inc\n", - " 121012\n", - " Prof-school\n", - " 15\n", + " 39544\n", + " 27\n", + " Private\n", + " 224849\n", + " HS-grad\n", + " 9\n", " Married-civ-spouse\n", - " Prof-specialty\n", + " Craft-repair\n", " Husband\n", " White\n", " Male\n", " 0\n", - " 1977\n", - " 50\n", + " 0\n", + " 40\n", " United-States\n", - " 1\n", + " 0\n", " \n", " \n", "\n", "" ], "text/plain": [ - " age workclass fnlwgt education educational_num \\\n", - "32823 33 Private 201988 Masters 14 \n", - "40713 31 Private 231826 HS-grad 9 \n", - "16020 38 Private 24126 Some-college 10 \n", - "32766 38 State-gov 312528 Bachelors 13 \n", - "9713 40 Self-emp-not-inc 121012 Prof-school 15 \n", + " age workclass fnlwgt education educational_num marital_status \\\n", + "3428 40 Private 144778 HS-grad 9 Married-civ-spouse \n", + "8234 38 ? 54953 HS-grad 9 Divorced \n", + "1129 28 Local-gov 134771 Bachelors 13 Never-married \n", + "11866 47 Private 189143 HS-grad 9 Divorced \n", + "39544 27 Private 224849 HS-grad 9 Married-civ-spouse \n", "\n", - " marital_status occupation relationship race gender \\\n", - "32823 Married-civ-spouse Prof-specialty Husband White Male \n", - "40713 Married-civ-spouse Other-service Husband White Male \n", - "16020 Divorced Exec-managerial Not-in-family White Female \n", - "32766 Married-civ-spouse Exec-managerial Husband White Male \n", - "9713 Married-civ-spouse Prof-specialty Husband White Male \n", + " occupation relationship race gender capital_gain \\\n", + "3428 Exec-managerial Husband White Male 0 \n", + "8234 ? Not-in-family White Male 0 \n", + "1129 Prof-specialty Own-child White Female 0 \n", + "11866 Farming-fishing Not-in-family White Male 0 \n", + "39544 Craft-repair Husband White Male 0 \n", "\n", - " capital_gain capital_loss hours_per_week native_country target \n", - "32823 0 0 45 United-States 0 \n", - "40713 0 0 52 Mexico 0 \n", - "16020 0 0 40 United-States 0 \n", - "32766 0 0 37 United-States 0 \n", - "9713 0 1977 50 United-States 1 " + " capital_loss hours_per_week native_country target \n", + "3428 0 50 United-States 1 \n", + "8234 0 30 United-States 0 \n", + "1129 0 55 United-States 0 \n", + "11866 0 40 United-States 0 \n", + "39544 0 40 United-States 0 " ] }, "execution_count": 23, @@ -1055,7 +1057,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "predict: 100%|██████████| 20/20 [00:00<00:00, 86.04it/s]\n" + "predict: 100%|██████████| 20/20 [00:00<00:00, 91.44it/s]\n" ] } ], @@ -1080,7 +1082,7 @@ { "data": { "text/plain": [ - "0.8554759467758444" + "0.8511770726714432" ] }, "execution_count": 32, diff --git a/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb b/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb index 39661b696efc843217cd2ae1ba18edb85d640eac..1203e0b920958a8de2a4c1a075da3d26e05a4318 100644 --- a/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb +++ b/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb @@ -752,7 +752,9 @@ "WideDeep(\n", " (deeptabular): Sequential(\n", " (0): TabMlp(\n", - " (cont_norm): BatchNorm1d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (cont_norm): BatchNorm1d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", " (tab_mlp): MLP(\n", " (mlp): Sequential(\n", " (dense_layer_0): Sequential(\n", @@ -861,12 +863,12 @@ "Consider using one of the following signatures instead:\n", "\tnonzero(Tensor input, *, bool as_tuple) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:766.)\n", " meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()\n", - "epoch 1: 100%|██████████| 208/208 [00:02<00:00, 76.68it/s, loss=0.225, metrics={'Accuracy': [0.927, 0.8861], 'Precision': 0.9064, 'Recall': [0.927, 0.8861], 'F1': [0.9075, 0.9052]}] \n", - "valid: 100%|██████████| 292/292 [00:02<00:00, 107.00it/s, loss=0.104, metrics={'Accuracy': [0.9626, 0.875], 'Precision': 0.9618, 'Recall': [0.9626, 0.875], 'F1': [0.9804, 0.2886]}] \n", - "epoch 2: 100%|██████████| 208/208 [00:02<00:00, 84.52it/s, loss=0.152, metrics={'Accuracy': [0.9471, 0.9298], 'Precision': 0.9384, 'Recall': [0.9471, 0.9298], 'F1': [0.9384, 0.9383]}]\n", - "valid: 100%|██████████| 292/292 [00:02<00:00, 107.06it/s, loss=0.0915, metrics={'Accuracy': [0.968, 0.8906], 'Precision': 0.9673, 'Recall': [0.968, 0.8906], 'F1': [0.9833, 0.3258]}] \n", - "epoch 3: 100%|██████████| 208/208 [00:02<00:00, 81.70it/s, loss=0.134, metrics={'Accuracy': [0.949, 0.9407], 'Precision': 0.9448, 'Recall': [0.949, 0.9407], 'F1': [0.9446, 0.9451]}] \n", - "valid: 100%|██████████| 292/292 [00:02<00:00, 104.52it/s, loss=0.0847, metrics={'Accuracy': [0.9679, 0.8828], 'Precision': 0.9672, 'Recall': [0.9679, 0.8828], 'F1': [0.9832, 0.3229]}]\n" + "epoch 1: 100%|██████████| 208/208 [00:02<00:00, 73.29it/s, loss=0.232, metrics={'Accuracy': [0.9226, 0.89], 'Precision': 0.9061, 'Recall': [0.9226, 0.89], 'F1': [0.9065, 0.9057]}] \n", + "valid: 100%|██████████| 292/292 [00:02<00:00, 107.70it/s, loss=0.0598, metrics={'Accuracy': [0.981, 0.8672], 'Precision': 0.98, 'Recall': [0.981, 0.8672], 'F1': [0.9898, 0.4341]}] \n", + "epoch 2: 100%|██████████| 208/208 [00:02<00:00, 85.53it/s, loss=0.161, metrics={'Accuracy': [0.9417, 0.9262], 'Precision': 0.9339, 'Recall': [0.9417, 0.9262], 'F1': [0.9344, 0.9335]}]\n", + "valid: 100%|██████████| 292/292 [00:02<00:00, 102.48it/s, loss=0.123, metrics={'Accuracy': [0.9549, 0.9219], 'Precision': 0.9546, 'Recall': [0.9549, 0.9219], 'F1': [0.9766, 0.2647]}]\n", + "epoch 3: 100%|██████████| 208/208 [00:02<00:00, 80.85it/s, loss=0.139, metrics={'Accuracy': [0.948, 0.9413], 'Precision': 0.9446, 'Recall': [0.948, 0.9413], 'F1': [0.9451, 0.9442]}] \n", + "valid: 100%|██████████| 292/292 [00:02<00:00, 104.11it/s, loss=0.0722, metrics={'Accuracy': [0.9743, 0.8984], 'Precision': 0.9737, 'Recall': [0.9743, 0.8984], 'F1': [0.9865, 0.3766]}]\n" ] }, { @@ -930,42 +932,42 @@ " \n", " \n", " 0\n", - " 0.225108\n", - " [0.9270001649856567, 0.8861073851585388]\n", - " 0.906365\n", - " [0.9270001649856567, 0.8861073851585388]\n", - " [0.9074797034263611, 0.9052220582962036]\n", - " 0.103954\n", - " [0.962550163269043, 0.875]\n", - " 0.961784\n", - " [0.962550163269043, 0.875]\n", - " [0.9803645014762878, 0.28863346576690674]\n", + " 0.232111\n", + " [0.9225957989692688, 0.8899886012077332]\n", + " 0.906075\n", + " [0.9225957989692688, 0.8899886012077332]\n", + " [0.9064720273017883, 0.9056749939918518]\n", + " 0.059772\n", + " [0.9809635877609253, 0.8671875]\n", + " 0.979966\n", + " [0.9809635877609253, 0.8671875]\n", + " [0.989802360534668, 0.4341084957122803]\n", " \n", " \n", " 1\n", - " 0.152386\n", - " [0.9471125602722168, 0.9297876358032227]\n", - " 0.938380\n", - " [0.9471125602722168, 0.9297876358032227]\n", - " [0.9384452104568481, 0.9383144974708557]\n", - " 0.091541\n", - " [0.9680188298225403, 0.890625]\n", - " 0.967341\n", - " [0.9680188298225403, 0.890625]\n", - " [0.9832653999328613, 0.3257790505886078]\n", + " 0.160898\n", + " [0.9417101144790649, 0.9261900186538696]\n", + " 0.933944\n", + " [0.9417101144790649, 0.9261900186538696]\n", + " [0.9344058036804199, 0.9334757328033447]\n", + " 0.122636\n", + " [0.954935610294342, 0.921875]\n", + " 0.954648\n", + " [0.954935610294342, 0.921875]\n", + " [0.9766026139259338, 0.2647385895252228]\n", " \n", " \n", " 2\n", - " 0.134425\n", - " [0.9490371346473694, 0.9407152533531189]\n", - " 0.944841\n", - " [0.9490371346473694, 0.9407152533531189]\n", - " [0.9446273446083069, 0.9450528621673584]\n", - " 0.084717\n", - " [0.967949628829956, 0.8828125]\n", - " 0.967204\n", - " [0.967949628829956, 0.8828125]\n", - " [0.9831950664520264, 0.3229461908340454]\n", + " 0.138879\n", + " [0.9479646682739258, 0.9413018226623535]\n", + " 0.944648\n", + " [0.9479646682739258, 0.9413018226623535]\n", + " [0.945061206817627, 0.9442285299301147]\n", + " 0.072151\n", + " [0.9743181467056274, 0.8984375]\n", + " 0.973653\n", + " [0.9743181467056274, 0.8984375]\n", + " [0.9865424036979675, 0.3766234219074249]\n", " \n", " \n", "\n", @@ -973,29 +975,29 @@ ], "text/plain": [ " train_loss train_Accuracy train_Precision \\\n", - "0 0.225108 [0.9270001649856567, 0.8861073851585388] 0.906365 \n", - "1 0.152386 [0.9471125602722168, 0.9297876358032227] 0.938380 \n", - "2 0.134425 [0.9490371346473694, 0.9407152533531189] 0.944841 \n", + "0 0.232111 [0.9225957989692688, 0.8899886012077332] 0.906075 \n", + "1 0.160898 [0.9417101144790649, 0.9261900186538696] 0.933944 \n", + "2 0.138879 [0.9479646682739258, 0.9413018226623535] 0.944648 \n", "\n", " train_Recall \\\n", - "0 [0.9270001649856567, 0.8861073851585388] \n", - "1 [0.9471125602722168, 0.9297876358032227] \n", - "2 [0.9490371346473694, 0.9407152533531189] \n", + "0 [0.9225957989692688, 0.8899886012077332] \n", + "1 [0.9417101144790649, 0.9261900186538696] \n", + "2 [0.9479646682739258, 0.9413018226623535] \n", "\n", " train_F1 val_loss \\\n", - "0 [0.9074797034263611, 0.9052220582962036] 0.103954 \n", - "1 [0.9384452104568481, 0.9383144974708557] 0.091541 \n", - "2 [0.9446273446083069, 0.9450528621673584] 0.084717 \n", + "0 [0.9064720273017883, 0.9056749939918518] 0.059772 \n", + "1 [0.9344058036804199, 0.9334757328033447] 0.122636 \n", + "2 [0.945061206817627, 0.9442285299301147] 0.072151 \n", "\n", - " val_Accuracy val_Precision \\\n", - "0 [0.962550163269043, 0.875] 0.961784 \n", - "1 [0.9680188298225403, 0.890625] 0.967341 \n", - "2 [0.967949628829956, 0.8828125] 0.967204 \n", + " val_Accuracy val_Precision \\\n", + "0 [0.9809635877609253, 0.8671875] 0.979966 \n", + "1 [0.954935610294342, 0.921875] 0.954648 \n", + "2 [0.9743181467056274, 0.8984375] 0.973653 \n", "\n", - " val_Recall val_F1 \n", - "0 [0.962550163269043, 0.875] [0.9803645014762878, 0.28863346576690674] \n", - "1 [0.9680188298225403, 0.890625] [0.9832653999328613, 0.3257790505886078] \n", - "2 [0.967949628829956, 0.8828125] [0.9831950664520264, 0.3229461908340454] " + " val_Recall val_F1 \n", + "0 [0.9809635877609253, 0.8671875] [0.989802360534668, 0.4341084957122803] \n", + "1 [0.954935610294342, 0.921875] [0.9766026139259338, 0.2647385895252228] \n", + "2 [0.9743181467056274, 0.8984375] [0.9865424036979675, 0.3766234219074249] " ] }, "execution_count": 14, @@ -1016,7 +1018,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "predict: 100%|██████████| 292/292 [00:00<00:00, 294.65it/s]\n" + "predict: 100%|██████████| 292/292 [00:01<00:00, 286.78it/s]\n" ] }, { @@ -1025,15 +1027,15 @@ "text": [ " precision recall f1-score support\n", "\n", - " 0 1.00 0.97 0.99 14446\n", - " 1 0.23 0.93 0.36 130\n", + " 0 1.00 0.98 0.99 14446\n", + " 1 0.27 0.93 0.41 130\n", "\n", - " accuracy 0.97 14576\n", - " macro avg 0.61 0.95 0.67 14576\n", - "weighted avg 0.99 0.97 0.98 14576\n", + " accuracy 0.98 14576\n", + " macro avg 0.63 0.95 0.70 14576\n", + "weighted avg 0.99 0.98 0.98 14576\n", "\n", "Actual predicted values:\n", - "(array([0, 1]), array([14039, 537]))\n" + "(array([0, 1]), array([14122, 454]))\n" ] } ], diff --git a/examples/10_The_Transformer_Family.ipynb b/examples/10_The_Transformer_Family.ipynb index 245b47013acf9442aba28fba2a6974a064dc0045..c6f922b9ce14a21a405c397ff1405cbd0e4c0518 100644 --- a/examples/10_The_Transformer_Family.ipynb +++ b/examples/10_The_Transformer_Family.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "At the moment there are 3 transformer-based algorithms available. \n", + "At the moment there are 5 transformer-based algorithms available. \n", "\n", "Here are examples of how to use them\n", "\n", @@ -34,7 +34,7 @@ "\n", "from pytorch_widedeep.preprocessing import TabPreprocessor\n", "from pytorch_widedeep.training import Trainer\n", - "from pytorch_widedeep.models import TabTransformer, SAINT, WideDeep\n", + "from pytorch_widedeep.models import TabTransformer, SAINT, FTTransformer, TabFastFormer, TabPerceiver, WideDeep\n", "from pytorch_widedeep.metrics import Accuracy" ] }, @@ -417,13 +417,6 @@ "X_tab = tab_preprocessor.fit_transform(df)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 6, @@ -435,7 +428,8 @@ "tab_transformer = TabTransformer(column_idx=tab_preprocessor.column_idx,\n", " embed_input=tab_preprocessor.embeddings_input,\n", " continuous_cols=tab_preprocessor.continuous_cols, \n", - " cont_norm_layer=\"layernorm\"\n", + " cont_norm_layer=\"batchnorm\", \n", + " n_blocks=4, n_heads=4 \n", " )" ] }, @@ -443,64 +437,27 @@ "cell_type": "code", "execution_count": 7, "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "TabTransformer(\n", - " (cat_embed): Embedding(103, 32, padding_idx=0)\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): LayerNorm((5,), eps=1e-05, elementwise_affine=True)\n", - " (transformer_blks): Sequential(\n", - " (block0): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (block1): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", + " (cat_and_cont_embed): CatAndContEmbeddings(\n", + " (cat_embed): CategoricalEmbeddings(\n", + " (embed): Embedding(103, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", - " (block2): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (cont_norm): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (transformer_blks): Sequential(\n", + " (transformer_block0): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -509,7 +466,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -517,11 +474,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block3): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block1): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -530,7 +488,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -538,11 +496,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block4): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block2): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -551,7 +510,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -559,11 +518,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block5): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block3): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -572,7 +532,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -634,8 +594,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 153/153 [00:23<00:00, 6.54it/s, loss=0.375, metrics={'acc': 0.8247}]\n", - "valid: 100%|██████████| 39/39 [00:01<00:00, 19.74it/s, loss=0.356, metrics={'acc': 0.8375}]\n" + "epoch 1: 100%|██████████| 153/153 [00:14<00:00, 10.56it/s, loss=0.356, metrics={'acc': 0.8321}]\n", + "valid: 100%|██████████| 39/39 [00:01<00:00, 36.52it/s, loss=0.336, metrics={'acc': 0.8465}]\n" ] } ], @@ -647,7 +607,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can also choose to use the `FT-Transformer` (just set `embed_continuous=True`), where continuous cols are also represented by \"Embeddings\", via a 1 layer MLP (with or without activation function). When using the `FT-Transformer` we can choose to use the `[CLS]` token as a pooling method or concatenate the output from the transformer blocks, as we did before. Let's use here the `[CLS]` token." + "We can also choose to use the `FT-Transformer`, where continuous cols are also represented by \"Embeddings\", via a 1 layer MLP (with or without activation function). When using the `FT-Transformer` we can choose to use the `[CLS]` token as a pooling method or concatenate the output from the transformer blocks, as we did before. Let's use here the `[CLS]` token. Also note that under the hood, the `FT-Transformer` uses Linear Attention. See [Linformer: Self-Attention with Linear Complexity](https://arxiv.org/pdf/2006.04768.pdf)" ] }, { @@ -670,170 +630,94 @@ "metadata": {}, "outputs": [], "source": [ - "# here all categorical columns will be encoded as 32 dim embeddings, then passed through the transformer \n", - "# blocks, concatenated with the continuous and finally through an MLP\n", - "ft_transformer = TabTransformer(column_idx=tab_preprocessor.column_idx,\n", - " embed_input=tab_preprocessor.embeddings_input,\n", - " continuous_cols=tab_preprocessor.continuous_cols, \n", - " embed_continuous=True,\n", - " embed_continuous_activation=None, \n", - " )" + "ft_transformer = FTTransformer(column_idx=tab_preprocessor.column_idx,\n", + " embed_input=tab_preprocessor.embeddings_input,\n", + " continuous_cols=tab_preprocessor.continuous_cols, \n", + " n_blocks=3, n_heads=6, input_dim=36\n", + " )" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [ { "data": { "text/plain": [ - "TabTransformer(\n", - " (cat_embed): Embedding(104, 32, padding_idx=0)\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): Identity()\n", - " (cont_embed): ContinuousEmbeddings()\n", - " (transformer_blks): Sequential(\n", - " (block0): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", + "FTTransformer(\n", + " (cat_and_cont_embed): CatAndContEmbeddings(\n", + " (cat_embed): CategoricalEmbeddings(\n", + " (embed): Embedding(104, 36, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", - " (block1): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (cont_norm): Identity()\n", + " (cont_embed): ContinuousEmbeddings()\n", + " )\n", + " (transformer_blks): Sequential(\n", + " (fttransformer_block0): FTTransformerEncoder(\n", + " (attn): LinearAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (qkv_proj): Linear(in_features=36, out_features=108, bias=False)\n", + " (out_proj): Linear(in_features=36, out_features=36, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", + " (w_1): Linear(in_features=36, out_features=94, bias=True)\n", + " (w_2): Linear(in_features=47, out_features=36, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", + " (activation): REGLU()\n", " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (attn_normadd): NormAdd(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (ff_addnorm): AddNorm(\n", + " (ff_normadd): NormAdd(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block2): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (fttransformer_block1): FTTransformerEncoder(\n", + " (attn): LinearAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (qkv_proj): Linear(in_features=36, out_features=108, bias=False)\n", + " (out_proj): Linear(in_features=36, out_features=36, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", + " (w_1): Linear(in_features=36, out_features=94, bias=True)\n", + " (w_2): Linear(in_features=47, out_features=36, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", + " (activation): REGLU()\n", " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (attn_normadd): NormAdd(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (ff_addnorm): AddNorm(\n", + " (ff_normadd): NormAdd(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block3): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (fttransformer_block2): FTTransformerEncoder(\n", + " (attn): LinearAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (qkv_proj): Linear(in_features=36, out_features=108, bias=False)\n", + " (out_proj): Linear(in_features=36, out_features=36, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", + " (w_1): Linear(in_features=36, out_features=94, bias=True)\n", + " (w_2): Linear(in_features=47, out_features=36, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", + " (activation): REGLU()\n", " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (attn_normadd): NormAdd(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (ff_addnorm): AddNorm(\n", + " (ff_normadd): NormAdd(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (block4): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (block5): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " )\n", - " (transformer_mlp): MLP(\n", - " (mlp): Sequential(\n", - " (dense_layer_0): Sequential(\n", - " (0): Linear(in_features=32, out_features=128, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (dense_layer_1): Sequential(\n", - " (0): Linear(in_features=128, out_features=64, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", + " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", " )\n", @@ -876,8 +760,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 153/153 [00:38<00:00, 4.02it/s, loss=0.412, metrics={'acc': 0.8002}]\n", - "valid: 100%|██████████| 39/39 [00:02<00:00, 14.69it/s, loss=0.326, metrics={'acc': 0.8524}]\n" + "epoch 1: 100%|██████████| 153/153 [00:15<00:00, 9.62it/s, loss=0.382, metrics={'acc': 0.8167}]\n", + "valid: 100%|██████████| 39/39 [00:01<00:00, 28.84it/s, loss=0.317, metrics={'acc': 0.8566}]\n" ] } ], @@ -894,7 +778,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -902,56 +786,61 @@ " embed_input=tab_preprocessor.embeddings_input,\n", " continuous_cols=tab_preprocessor.continuous_cols, \n", " transformer_activation=\"geglu\",\n", - " embed_continuous=True,\n", - " embed_continuous_activation=None, \n", + " n_blocks=2, n_heads=4, \n", " )" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "SAINT(\n", - " (cat_embed): Embedding(104, 32, padding_idx=0)\n", - " (embedding_dropout): Dropout(p=0.1, inplace=False)\n", - " (cont_norm): LayerNorm((5,), eps=1e-05, elementwise_affine=True)\n", - " (cont_embed): ContinuousEmbeddings()\n", + " (cat_and_cont_embed): CatAndContEmbeddings(\n", + " (cat_embed): CategoricalEmbeddings(\n", + " (embed): Embedding(104, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (cont_norm): Identity()\n", + " (cont_embed): ContinuousEmbeddings()\n", + " )\n", " (transformer_blks): Sequential(\n", - " (block0): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", + " (saint_block0): SaintEncoder(\n", + " (col_attn): MultiHeadedAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", - " (self_attn_ff): PositionwiseFF(\n", + " (col_attn_ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=256, bias=True)\n", " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): GEGLU()\n", " )\n", - " (self_attn_addnorm): AddNorm(\n", + " (col_attn_addnorm): AddNorm(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (col_attn_ff_addnorm): AddNorm(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (row_attn): MultiHeadedAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n", - " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n", + " (q_proj): Linear(in_features=448, out_features=448, bias=False)\n", + " (kv_proj): Linear(in_features=448, out_features=896, bias=False)\n", + " (out_proj): Linear(in_features=448, out_features=448, bias=False)\n", " )\n", " (row_attn_ff): PositionwiseFF(\n", " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n", " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): GEGLU()\n", " )\n", " (row_attn_addnorm): AddNorm(\n", @@ -959,39 +848,41 @@ " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block1): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", + " (saint_block1): SaintEncoder(\n", + " (col_attn): MultiHeadedAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", - " (self_attn_ff): PositionwiseFF(\n", + " (col_attn_ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=256, bias=True)\n", " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): GEGLU()\n", " )\n", - " (self_attn_addnorm): AddNorm(\n", + " (col_attn_addnorm): AddNorm(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (col_attn_ff_addnorm): AddNorm(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (row_attn): MultiHeadedAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n", - " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n", + " (q_proj): Linear(in_features=448, out_features=448, bias=False)\n", + " (kv_proj): Linear(in_features=448, out_features=896, bias=False)\n", + " (out_proj): Linear(in_features=448, out_features=448, bias=False)\n", " )\n", " (row_attn_ff): PositionwiseFF(\n", " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n", " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): GEGLU()\n", " )\n", " (row_attn_addnorm): AddNorm(\n", @@ -999,169 +890,150 @@ " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block2): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (self_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=256, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GEGLU()\n", - " )\n", - " (self_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n", - " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n", - " )\n", - " (row_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n", - " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GEGLU()\n", - " )\n", - " (row_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (transformer_mlp): MLP(\n", + " (mlp): Sequential(\n", + " (dense_layer_0): Sequential(\n", + " (0): Linear(in_features=32, out_features=128, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", " )\n", - " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", + " (dense_layer_1): Sequential(\n", + " (0): Linear(in_features=128, out_features=64, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", - " (block3): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (self_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=256, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GEGLU()\n", - " )\n", - " (self_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n", - " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n", - " )\n", - " (row_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n", - " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GEGLU()\n", - " )\n", - " (row_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", - " )\n", + " )\n", + ")" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "saint" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|██████████| 306/306 [00:47<00:00, 6.42it/s, loss=0.377, metrics={'acc': 0.8224}]\n", + "valid: 100%|██████████| 77/77 [00:02<00:00, 32.20it/s, loss=0.338, metrics={'acc': 0.8529}]\n" + ] + } + ], + "source": [ + "model = WideDeep(deeptabular=saint)\n", + "trainer = Trainer(model, objective='binary', metrics=[Accuracy])\n", + "trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=128, val_split=0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The previous models have all been published. The following two are adaptations of existing Transformer models for tabular data and by the time I am writing this they are only available in this library. If I have the time I will write a post about their implementation. Nonetheless, all the details can be found in the [docs](https://pytorch-widedeep.readthedocs.io/en/latest/index.html).\n", + "\n", + "The first one is an adaptation of [Fastformer: Additive Attention Can Be All You Need](https://arxiv.org/pdf/2108.09084.pdf). I have mixed feelings towards that paper, that I will not be covering here, but you can go and watch [Yannic's video](https://www.youtube.com/watch?v=qgUegkefocg&t=1s) since most of my opinions are also explained there. Nonetheless, the reason to bring this model to the library is because in essence, the `FastFormer` is an \"elaborated MLP\" with an \"interesting\" attention aggregated attention mechanism. Since MLPs work really well for tabular data compared to other, more complex models, why not add it to the library. \n", + "\n", + "To use it, just follow the same routine as with any other transformer-based model" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "tabfastformer = TabFastFormer(column_idx=tab_preprocessor.column_idx,\n", + " embed_input=tab_preprocessor.embeddings_input,\n", + " continuous_cols=tab_preprocessor.continuous_cols, \n", + " n_blocks=2, n_heads=4, \n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "TabFastFormer(\n", + " (cat_and_cont_embed): CatAndContEmbeddings(\n", + " (cat_embed): CategoricalEmbeddings(\n", + " (embed): Embedding(104, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", - " (block4): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", + " (cont_norm): Identity()\n", + " (cont_embed): ContinuousEmbeddings()\n", + " )\n", + " (transformer_blks): Sequential(\n", + " (fastformer_block0): FastFormerEncoder(\n", + " (attn): AdditiveAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (v_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (k_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (W_q): Linear(in_features=8, out_features=1, bias=False)\n", + " (W_k): Linear(in_features=8, out_features=1, bias=False)\n", + " (r_out): Linear(in_features=8, out_features=8, bias=True)\n", " )\n", - " (self_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=256, bias=True)\n", + " (ff): PositionwiseFF(\n", + " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GEGLU()\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (activation): ReLU(inplace=True)\n", " )\n", - " (self_attn_addnorm): AddNorm(\n", + " (attn_addnorm): AddNorm(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (ff_addnorm): AddNorm(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (row_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n", - " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n", - " )\n", - " (row_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n", - " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GEGLU()\n", - " )\n", - " (row_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", - " )\n", " )\n", - " (block5): SaintEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", + " (fastformer_block1): FastFormerEncoder(\n", + " (attn): AdditiveAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (v_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (k_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (W_q): Linear(in_features=8, out_features=1, bias=False)\n", + " (W_k): Linear(in_features=8, out_features=1, bias=False)\n", + " (r_out): Linear(in_features=8, out_features=8, bias=True)\n", " )\n", - " (self_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=256, bias=True)\n", + " (ff): PositionwiseFF(\n", + " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GEGLU()\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (activation): ReLU(inplace=True)\n", " )\n", - " (self_attn_addnorm): AddNorm(\n", + " (attn_addnorm): AddNorm(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (self_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (ff_addnorm): AddNorm(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", - " (row_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n", - " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n", - " )\n", - " (row_attn_ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n", - " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GEGLU()\n", - " )\n", - " (row_attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (row_attn_ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n", - " )\n", " )\n", " )\n", " (transformer_mlp): MLP(\n", @@ -1181,47 +1053,107 @@ ")" ] }, - "execution_count": 19, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "saint" + "tabfastformer" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 306/306 [02:34<00:00, 1.98it/s, loss=0.555, metrics={'acc': 0.7607}]\n", - "valid: 100%|██████████| 77/77 [00:08<00:00, 8.96it/s, loss=0.551, metrics={'acc': 0.7607}]\n" + "epoch 1: 100%|██████████| 153/153 [00:10<00:00, 14.58it/s, loss=0.46, metrics={'acc': 0.7867}] \n", + "valid: 100%|██████████| 39/39 [00:00<00:00, 48.19it/s, loss=0.342, metrics={'acc': 0.8443}]\n" ] } ], "source": [ - "model = WideDeep(deeptabular=saint)\n", + "model = WideDeep(deeptabular=tabfastformer)\n", "trainer = Trainer(model, objective='binary', metrics=[Accuracy])\n", - "trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=128, val_split=0.2)" + "trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And finally, the last of the transformer-based models that are currently available in the library is DeepMind's [Perceiver](https://arxiv.org/pdf/2103.03206.pdf). The reason to add this model to the library is the following. The Perceiver is meant to be an architecture agnostic of the nature of the input data, i.e. it is meant to work with audio, images, text...So why not tabular, right? \n", + "\n", + "To use it...you guessed right! " + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "tab_preprocessor = TabPreprocessor(embed_cols=cat_cols, \n", + " continuous_cols=cont_cols, \n", + " for_transformer=True,\n", + " )\n", + "X_tab = tab_preprocessor.fit_transform(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "tabperceiver = TabPerceiver(\n", + " column_idx=tab_preprocessor.column_idx,\n", + " embed_input=tab_preprocessor.embeddings_input,\n", + " continuous_cols=tab_preprocessor.continuous_cols, \n", + " n_perceiver_blocks=1, \n", + " n_latent_blocks=3, \n", + " n_latent_heads=2, \n", + " n_latents=6,\n", + " latent_dim=32,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|██████████| 153/153 [00:16<00:00, 9.45it/s, loss=0.4, metrics={'acc': 0.81}] \n", + "valid: 100%|██████████| 39/39 [00:01<00:00, 37.95it/s, loss=0.323, metrics={'acc': 0.8542}]\n" + ] + } + ], + "source": [ + "model = WideDeep(deeptabular=tabperceiver)\n", + "trainer = Trainer(model, objective='binary', metrics=[Accuracy])\n", + "trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "One final comment is that all 3 transformer-based model have the option of using the so called \"Shared Embeddings\". The idea behind the shared embeddings is explained in the original TabTransformer paper and also here in this [post](https://jrzaurin.github.io/infinitoml/2021/02/18/pytorch-widedeep_iii.html).\n", + "One final comment is that all transformer-based models have the option of using the so called \"Shared Embeddings\". The idea behind the shared embeddings is explained in the original TabTransformer paper and also here in this [post](https://jrzaurin.github.io/infinitoml/2021/02/18/pytorch-widedeep_iii.html).\n", "\n", "For transformer-based models this implies a bit of a different data preparation process since each column will be encoded individually (programmatically is way easier to implement) and the use of shared embeddings needs to be specified at preprocessing stage" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -1236,7 +1168,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -1246,12 +1178,14 @@ " embed_continuous=True,\n", " embed_continuous_activation=None, \n", " shared_embed=True, \n", + " cont_norm_layer=\"batchnorm\", \n", + " n_blocks=4, n_heads=4 \n", " )" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 28, "metadata": { "scrolled": false }, @@ -1260,73 +1194,57 @@ "data": { "text/plain": [ "TabTransformer(\n", - " (cat_embed): ModuleDict(\n", - " (emb_layer_cls_token): SharedEmbeddings(\n", - " (embed): Embedding(1, 32, padding_idx=0)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (emb_layer_education): SharedEmbeddings(\n", - " (embed): Embedding(17, 32, padding_idx=0)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (emb_layer_gender): SharedEmbeddings(\n", - " (embed): Embedding(3, 32, padding_idx=0)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (emb_layer_marital_status): SharedEmbeddings(\n", - " (embed): Embedding(8, 32, padding_idx=0)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (emb_layer_native_country): SharedEmbeddings(\n", - " (embed): Embedding(43, 32, padding_idx=0)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (emb_layer_occupation): SharedEmbeddings(\n", - " (embed): Embedding(16, 32, padding_idx=0)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (emb_layer_race): SharedEmbeddings(\n", - " (embed): Embedding(6, 32, padding_idx=0)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (emb_layer_relationship): SharedEmbeddings(\n", - " (embed): Embedding(7, 32, padding_idx=0)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (emb_layer_workclass): SharedEmbeddings(\n", - " (embed): Embedding(10, 32, padding_idx=0)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (cat_and_cont_embed): CatAndContEmbeddings(\n", + " (cat_embed): CategoricalEmbeddings(\n", + " (embed): ModuleDict(\n", + " (emb_layer_cls_token): SharedEmbeddings(\n", + " (embed): Embedding(1, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (emb_layer_education): SharedEmbeddings(\n", + " (embed): Embedding(17, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (emb_layer_gender): SharedEmbeddings(\n", + " (embed): Embedding(3, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (emb_layer_marital_status): SharedEmbeddings(\n", + " (embed): Embedding(8, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (emb_layer_native_country): SharedEmbeddings(\n", + " (embed): Embedding(43, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (emb_layer_occupation): SharedEmbeddings(\n", + " (embed): Embedding(16, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (emb_layer_race): SharedEmbeddings(\n", + " (embed): Embedding(6, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (emb_layer_relationship): SharedEmbeddings(\n", + " (embed): Embedding(7, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (emb_layer_workclass): SharedEmbeddings(\n", + " (embed): Embedding(10, 32, padding_idx=0)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", " )\n", + " (cont_norm): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (cont_embed): ContinuousEmbeddings()\n", " )\n", - " (cont_norm): Identity()\n", - " (cont_embed): ContinuousEmbeddings()\n", " (transformer_blks): Sequential(\n", - " (block0): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (block1): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block0): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -1335,7 +1253,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -1343,11 +1261,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block2): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block1): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -1356,7 +1275,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -1364,11 +1283,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block3): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block2): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -1377,7 +1297,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -1385,11 +1305,12 @@ " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", - " (block4): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", + " (transformer_block3): TransformerEncoder(\n", + " (attn): MultiHeadedAttention(\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n", + " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n", + " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n", " )\n", " (ff): PositionwiseFF(\n", " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", @@ -1398,28 +1319,7 @@ " (activation): GELU()\n", " )\n", " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (ff_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (block5): TransformerEncoder(\n", - " (self_attn): MultiHeadedAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n", - " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n", - " )\n", - " (ff): PositionwiseFF(\n", - " (w_1): Linear(in_features=32, out_features=128, bias=True)\n", - " (w_2): Linear(in_features=128, out_features=32, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (activation): GELU()\n", - " )\n", - " (attn_addnorm): AddNorm(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (ff_addnorm): AddNorm(\n", @@ -1445,7 +1345,7 @@ ")" ] }, - "execution_count": 23, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1456,15 +1356,15 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 153/153 [00:35<00:00, 4.32it/s, loss=0.411, metrics={'acc': 0.8036}]\n", - "valid: 100%|██████████| 39/39 [00:02<00:00, 15.03it/s, loss=0.319, metrics={'acc': 0.8583}]\n" + "epoch 1: 100%|██████████| 153/153 [00:20<00:00, 7.62it/s, loss=0.4, metrics={'acc': 0.8061}] \n", + "valid: 100%|██████████| 39/39 [00:01<00:00, 30.53it/s, loss=0.324, metrics={'acc': 0.8551}]\n" ] } ], diff --git a/examples/11_Extracting_Embeddings.ipynb b/examples/11_Extracting_Embeddings.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5504f84dfb8f6c254e2d847cbf2afd7d02c43d67 --- /dev/null +++ b/examples/11_Extracting_Embeddings.ipynb @@ -0,0 +1,513 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/javier/.pyenv/versions/3.7.7/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n", + " return f(*args, **kwds)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "\n", + "from pytorch_widedeep.preprocessing import TabPreprocessor\n", + "from pytorch_widedeep.training import Trainer\n", + "from pytorch_widedeep.models import FTTransformer, WideDeep\n", + "from pytorch_widedeep.metrics import Accuracy\n", + "from pytorch_widedeep import Tab2Vec" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducational-nummarital-statusoccupationrelationshipracegendercapital-gaincapital-losshours-per-weeknative-countryincome
025Private22680211th7Never-marriedMachine-op-inspctOwn-childBlackMale0040United-States<=50K
138Private89814HS-grad9Married-civ-spouseFarming-fishingHusbandWhiteMale0050United-States<=50K
228Local-gov336951Assoc-acdm12Married-civ-spouseProtective-servHusbandWhiteMale0040United-States>50K
344Private160323Some-college10Married-civ-spouseMachine-op-inspctHusbandBlackMale7688040United-States>50K
418?103497Some-college10Never-married?Own-childWhiteFemale0030United-States<=50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education educational-num marital-status \\\n", + "0 25 Private 226802 11th 7 Never-married \n", + "1 38 Private 89814 HS-grad 9 Married-civ-spouse \n", + "2 28 Local-gov 336951 Assoc-acdm 12 Married-civ-spouse \n", + "3 44 Private 160323 Some-college 10 Married-civ-spouse \n", + "4 18 ? 103497 Some-college 10 Never-married \n", + "\n", + " occupation relationship race gender capital-gain capital-loss \\\n", + "0 Machine-op-inspct Own-child Black Male 0 0 \n", + "1 Farming-fishing Husband White Male 0 0 \n", + "2 Protective-serv Husband White Male 0 0 \n", + "3 Machine-op-inspct Husband Black Male 7688 0 \n", + "4 ? Own-child White Female 0 0 \n", + "\n", + " hours-per-week native-country income \n", + "0 40 United-States <=50K \n", + "1 50 United-States <=50K \n", + "2 40 United-States >50K \n", + "3 40 United-States >50K \n", + "4 30 United-States <=50K " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('data/adult/adult.csv.zip')\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationmarital_statusoccupationrelationshipracegendercapital_gaincapital_losshours_per_weeknative_countrytarget
025Private22680211thNever-marriedMachine-op-inspctOwn-childBlackMale0040United-States0
138Private89814HS-gradMarried-civ-spouseFarming-fishingHusbandWhiteMale0050United-States0
228Local-gov336951Assoc-acdmMarried-civ-spouseProtective-servHusbandWhiteMale0040United-States1
344Private160323Some-collegeMarried-civ-spouseMachine-op-inspctHusbandBlackMale7688040United-States1
418?103497Some-collegeNever-married?Own-childWhiteFemale0030United-States0
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education marital_status \\\n", + "0 25 Private 226802 11th Never-married \n", + "1 38 Private 89814 HS-grad Married-civ-spouse \n", + "2 28 Local-gov 336951 Assoc-acdm Married-civ-spouse \n", + "3 44 Private 160323 Some-college Married-civ-spouse \n", + "4 18 ? 103497 Some-college Never-married \n", + "\n", + " occupation relationship race gender capital_gain capital_loss \\\n", + "0 Machine-op-inspct Own-child Black Male 0 0 \n", + "1 Farming-fishing Husband White Male 0 0 \n", + "2 Protective-serv Husband White Male 0 0 \n", + "3 Machine-op-inspct Husband Black Male 7688 0 \n", + "4 ? Own-child White Female 0 0 \n", + "\n", + " hours_per_week native_country target \n", + "0 40 United-States 0 \n", + "1 50 United-States 0 \n", + "2 40 United-States 1 \n", + "3 40 United-States 1 \n", + "4 30 United-States 0 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# For convenience, we'll replace '-' with '_'\n", + "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n", + "#binary target\n", + "df['target'] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n", + "df.drop([\"income\", \"educational_num\"], axis=1, inplace=True)\n", + "\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "cat_cols, cont_cols = [], []\n", + "for col in df.columns:\n", + " # 50 is just a random number I choose here for this example\n", + " if df[col].dtype == \"O\" or df[col].nunique() < 50 and col != \"target\":\n", + " cat_cols.append(col)\n", + " elif col != \"target\": \n", + " cont_cols.append(col)\n", + "target_col = \"target\"" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "target = df[target_col].values\n", + "\n", + "tab_preprocessor = TabPreprocessor(embed_cols=cat_cols, \n", + " continuous_cols=cont_cols, \n", + " for_transformer=True\n", + " )\n", + "X_tab = tab_preprocessor.fit_transform(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "ft_transformer = FTTransformer(column_idx=tab_preprocessor.column_idx,\n", + " embed_input=tab_preprocessor.embeddings_input,\n", + " continuous_cols=tab_preprocessor.continuous_cols, \n", + " n_blocks=3, n_heads=6, input_dim=36\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|██████████| 153/153 [00:15<00:00, 10.12it/s, loss=0.355, metrics={'acc': 0.8326}]\n", + "valid: 100%|██████████| 39/39 [00:01<00:00, 26.68it/s, loss=0.308, metrics={'acc': 0.8598}]\n" + ] + } + ], + "source": [ + "model = WideDeep(deeptabular=ft_transformer)\n", + "trainer = Trainer(model, objective='binary', metrics=[Accuracy])\n", + "trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "t2v = Tab2Vec(model=model, tab_preprocessor=tab_preprocessor)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# assuming is a test set with target col\n", + "X_vec, y = t2v.transform(df.sample(100), target_col=\"target\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(100, 468)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# X vec is the dataframe turned into the embeddings\n", + "X_vec.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`468 = input_dim (36) * n_cols (13)`" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# ...or if we don't have target col\n", + "X_vec = t2v.transform(df.sample(100))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/adult_census_transformers.py b/examples/adult_census_transformers.py index 5973e9acb96399c6494b99640c46ac1ef4d53a0f..9bef4ded14978a0efbe749be8a313d146d990e57 100644 --- a/examples/adult_census_transformers.py +++ b/examples/adult_census_transformers.py @@ -3,8 +3,15 @@ import torch import pandas as pd from pytorch_widedeep import Trainer -from pytorch_widedeep.optim import RAdam -from pytorch_widedeep.models import SAINT, Wide, WideDeep, TabTransformer +from pytorch_widedeep.models import ( + SAINT, + Wide, + WideDeep, + TabPerceiver, + FTTransformer, + TabFastFormer, + TabTransformer, +) from pytorch_widedeep.metrics import Accuracy from pytorch_widedeep.callbacks import ( LRHistory, @@ -64,22 +71,60 @@ if __name__ == "__main__": embed_input=prepare_deep.embeddings_input, continuous_cols=continuous_cols, embed_continuous=True, + n_blocks=4, ) saint = SAINT( column_idx=prepare_deep.column_idx, embed_input=prepare_deep.embeddings_input, continuous_cols=continuous_cols, - embed_continuous=True, cont_norm_layer="batchnorm", + n_blocks=4, + ) + + tab_perceiver = TabPerceiver( + column_idx=prepare_deep.column_idx, + embed_input=prepare_deep.embeddings_input, + continuous_cols=continuous_cols, + n_latents=6, + latent_dim=16, + n_latent_blocks=4, + n_perceiver_blocks=2, + share_weights=False, + ) + + tab_fastformer = TabFastFormer( + column_idx=prepare_deep.column_idx, + embed_input=prepare_deep.embeddings_input, + continuous_cols=continuous_cols, + n_blocks=4, + n_heads=4, + share_qv_weights=False, + share_weights=False, + ) + + ft_transformer = FTTransformer( + column_idx=prepare_deep.column_idx, + embed_input=prepare_deep.embeddings_input, + continuous_cols=continuous_cols, + input_dim=32, + kv_compression_factor=0.5, + n_blocks=3, + n_heads=4, ) - for tab_model in [tab_transformer, saint]: + for tab_model in [ + tab_transformer, + saint, + ft_transformer, + tab_perceiver, + tab_fastformer, + ]: model = WideDeep(wide=wide, deeptabular=tab_model) wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01) - deep_opt = RAdam(model.deeptabular.parameters()) + deep_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01) wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3) deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5) diff --git a/pypi_README.md b/pypi_README.md index 12f0e4d48f272b6a1af52c1cdeef96e21d59822d..f75773ab72883000eca55bf07548c338608c20cc 100644 --- a/pypi_README.md +++ b/pypi_README.md @@ -4,7 +4,7 @@ [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity) [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues) [![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep) -[![Python 3.6 3.7 3.8 3.9](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/) +[![Python 3.7 3.8 3.9](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/) # pytorch-widedeep @@ -63,8 +63,8 @@ when running on Mac, present in previous versions, persist on this release and the data-loaders will not run in parallel. In addition, since `python 3.8`, [the `multiprocessing` library start method changed from `'fork'` to`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods). This also affects the data-loaders (for any `torch` version) and they will -not run in parallel. Therefore, for Mac users I recommend using `python -3.6` or `3.7` and `torch <= 1.6` (with the corresponding, consistent +not run in parallel. Therefore, for Mac users I recommend using `python 3.7` +and `torch <= 1.6` (with the corresponding, consistent version of `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this versioning in the `setup.py` file since I expect that all these issues are fixed in the future. Therefore, after installing diff --git a/pytorch_widedeep/__init__.py b/pytorch_widedeep/__init__.py index 1a06801695ca59a7916501a079d9c50ed77f4525..cdcad2eb076ddaa296cc057b838a634717b3fd97 100644 --- a/pytorch_widedeep/__init__.py +++ b/pytorch_widedeep/__init__.py @@ -12,5 +12,6 @@ from pytorch_widedeep.utils import ( deeptabular_utils, fastai_transforms, ) +from pytorch_widedeep.tab2vec import Tab2Vec from pytorch_widedeep.version import __version__ from pytorch_widedeep.training import Trainer diff --git a/pytorch_widedeep/callbacks.py b/pytorch_widedeep/callbacks.py index 4fd4a7c6f0bed4b100d1378338f80369d927e566..f39b6d5724982302c64375bd6fc3e1bb3a04dea9 100644 --- a/pytorch_widedeep/callbacks.py +++ b/pytorch_widedeep/callbacks.py @@ -341,19 +341,19 @@ class ModelCheckpoint(Callback): weights_out_2.pt, ...`` monitor: str, default="loss" quantity to monitor. Typically 'val_loss' or metric name (e.g. 'val_acc') - verbose:int, default=0, + verbose:int, default=0 verbosity mode save_best_only: bool, default=False, the latest best model according to the quantity monitored will not be overwritten. - mode: str, default="auto", + mode: str, default="auto" If ``save_best_only=True``, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For `'acc'`, this should be `'max'`, for `'loss'` this should be `'min'`, etc. In `'auto'` mode, the direction is automatically inferred from the name of the monitored quantity. - period: int, default=1, + period: int, default=1 Interval (number of epochs) between checkpoints. max_save: int, default=-1 Maximum number of outputs to save. If -1 will save all outputs @@ -425,11 +425,11 @@ class ModelCheckpoint(Callback): self.monitor_op = np.less self.best = np.Inf elif self.mode == "max": - self.monitor_op = np.greater + self.monitor_op = np.greater # type: ignore[assignment] self.best = -np.Inf else: if _is_metric(self.monitor): - self.monitor_op = np.greater + self.monitor_op = np.greater # type: ignore[assignment] self.best = -np.Inf else: self.monitor_op = np.less @@ -596,10 +596,10 @@ class EarlyStopping(Callback): if self.mode == "min": self.monitor_op = np.less elif self.mode == "max": - self.monitor_op = np.greater + self.monitor_op = np.greater # type: ignore[assignment] else: if _is_metric(self.monitor): - self.monitor_op = np.greater + self.monitor_op = np.greater # type: ignore[assignment] else: self.monitor_op = np.less diff --git a/pytorch_widedeep/models/__init__.py b/pytorch_widedeep/models/__init__.py index 85925cb0bfbf71180d2e3f6774fe9211a75dbd61..14fbba2e4026215f670b6a8ec0b77ecb4d420471 100644 --- a/pytorch_widedeep/models/__init__.py +++ b/pytorch_widedeep/models/__init__.py @@ -6,4 +6,7 @@ from pytorch_widedeep.models.deep_image import DeepImage from pytorch_widedeep.models.tab_resnet import TabResnet from pytorch_widedeep.models.tabnet.tab_net import TabNet from pytorch_widedeep.models.transformers.saint import SAINT +from pytorch_widedeep.models.transformers.tab_perceiver import TabPerceiver +from pytorch_widedeep.models.transformers.ft_transformer import FTTransformer +from pytorch_widedeep.models.transformers.tab_fastformer import TabFastFormer from pytorch_widedeep.models.transformers.tab_transformer import TabTransformer diff --git a/pytorch_widedeep/models/deep_image.py b/pytorch_widedeep/models/deep_image.py index a8ddc1a8bc621d0daaea0a4fcaed8a746a349781..218ecbe990c4679b2222061813547fa94d5c52a9 100644 --- a/pytorch_widedeep/models/deep_image.py +++ b/pytorch_widedeep/models/deep_image.py @@ -55,11 +55,12 @@ class DeepImage(nn.Module): The resnet architecture. One of 18, 34 or 50 freeze_n: int, default = 6 number of layers to freeze. Must be less than or equal to 8. If 8 - the entire 'backbone' of the nwtwork will be frozen + the entire 'backbone' of the network will be frozen head_hidden_dims: List, Optional, default = None List with the number of neurons per dense layer in the head. e.g: [64,32] head_activation: str, default = "relu" - Activation function for the dense layers in the head. + Activation function for the dense layers in the head. Currently + ``tanh``, ``relu``, ``leaky_relu`` and ``gelu`` are supported head_dropout: float, default = 0.1 float indicating the dropout between the dense layers. head_batchnorm: bool, default = False diff --git a/pytorch_widedeep/models/deep_text.py b/pytorch_widedeep/models/deep_text.py index ee298ccb186f111fe74781418bfece114e717cb7..b709efc32b6a5634e9baa80bac1a386a223fc9bb 100644 --- a/pytorch_widedeep/models/deep_text.py +++ b/pytorch_widedeep/models/deep_text.py @@ -20,7 +20,7 @@ class DeepText(nn.Module): vocab_size: int number of words in the vocabulary rnn_type: str, default = 'lstm' - String indicating the type of RNN to use. One of "lstm" or "gru" + String indicating the type of RNN to use. One of ``lstm`` or ``gru`` hidden_dim: int, default = 64 Hidden dim of the RNN n_layers: int, default = 3 @@ -30,9 +30,9 @@ class DeepText(nn.Module): the last layer bidirectional: bool, default = True indicates whether the staked RNNs are bidirectional - use_hidden_state: str, default = True, + use_hidden_state: str, default = True Boolean indicating whether to use the final hidden state or the - rnn output as predicting features + RNN output as predicting features padding_idx: int, default = 1 index of the padding token in the padded-tokenised sequences. I use the ``fastai`` tokenizer where the token index 0 is reserved @@ -48,7 +48,8 @@ class DeepText(nn.Module): List with the sizes of the stacked dense layers in the head e.g: [128, 64] head_activation: str, default = "relu" - Activation function for the dense layers in the head + Activation function for the dense layers in the head. Currently + ``tanh``, ``relu``, ``leaky_relu`` and ``gelu`` are supported head_dropout: float, Optional, default = None dropout between the dense layers in the head head_batchnorm: bool, default = False diff --git a/pytorch_widedeep/models/tab_mlp.py b/pytorch_widedeep/models/tab_mlp.py index b1b92865bd182aed701852662be56f1db8d88fc4..7c361345cd2e6fdb73051ac4c3eaa6f38786f177 100644 --- a/pytorch_widedeep/models/tab_mlp.py +++ b/pytorch_widedeep/models/tab_mlp.py @@ -5,7 +5,7 @@ from torch import nn from pytorch_widedeep.wdtypes import * # noqa: F403 -allowed_activations = ["relu", "leaky_relu", "gelu", "geglu"] +allowed_activations = ["relu", "leaky_relu", "tanh", "gelu", "geglu", "reglu"] class GEGLU(nn.Module): @@ -14,15 +14,25 @@ class GEGLU(nn.Module): return x * F.gelu(gates) -def _get_activation_fn(activation): +class REGLU(nn.Module): + def forward(self, x): + x, gates = x.chunk(2, dim=-1) + return x * F.gelu(gates) + + +def get_activation_fn(activation): if activation == "relu": return nn.ReLU(inplace=True) if activation == "leaky_relu": return nn.LeakyReLU(inplace=True) - elif activation == "gelu": + if activation == "tanh": + return nn.Tanh() + if activation == "gelu": return nn.GELU() - elif activation == "geglu": + if activation == "geglu": return GEGLU() + if activation == "reglu": + return REGLU() def dense_layer( @@ -37,9 +47,9 @@ def dense_layer( if activation == "geglu": raise ValueError( "'geglu' activation is only used as 'transformer_activation' " - "in transformer-based models (TabTransformer and SAINT)" + "in transformer-based models" ) - act_fn = _get_activation_fn(activation) + act_fn = get_activation_fn(activation) layers = [nn.BatchNorm1d(out if linear_first else inp)] if bn else [] if p != 0: layers.append(nn.Dropout(p)) # type: ignore[arg-type] @@ -48,6 +58,69 @@ def dense_layer( return nn.Sequential(*layers) +class CatEmbeddingsAndCont(nn.Module): + def __init__( + self, + column_idx: Dict[str, int], + embed_input: List[Tuple[str, int, int]], + embed_dropout: float, + continuous_cols: Optional[List[str]], + cont_norm_layer: str, + ): + super(CatEmbeddingsAndCont, self).__init__() + + self.column_idx = column_idx + self.embed_input = embed_input + self.continuous_cols = continuous_cols + + # Embeddings: val + 1 because 0 is reserved for padding/unseen cateogories. + if self.embed_input is not None: + self.embed_layers = nn.ModuleDict( + { + "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0) + for col, val, dim in self.embed_input + } + ) + self.embedding_dropout = nn.Dropout(embed_dropout) + self.emb_out_dim: int = int( + np.sum([embed[2] for embed in self.embed_input]) + ) + else: + self.emb_out_dim = 0 + + # Continuous + if self.continuous_cols is not None: + self.cont_idx = [self.column_idx[col] for col in self.continuous_cols] + self.cont_out_dim: int = len(self.continuous_cols) + if cont_norm_layer == "batchnorm": + self.cont_norm: NormLayers = nn.BatchNorm1d(self.cont_out_dim) + elif cont_norm_layer == "layernorm": + self.cont_norm = nn.LayerNorm(self.cont_out_dim) + else: + self.cont_norm = nn.Identity() + else: + self.cont_out_dim = 0 + + self.output_dim = self.emb_out_dim + self.cont_out_dim + + def forward(self, X: Tensor) -> Tuple[Tensor, Any]: + if self.embed_input is not None: + embed = [ + self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long()) + for col, _, _ in self.embed_input + ] + x_emb = torch.cat(embed, 1) + x_emb = self.embedding_dropout(x_emb) + else: + x_emb = None + if self.continuous_cols is not None: + x_cont = self.cont_norm((X[:, self.cont_idx].float())) + else: + x_cont = None + + return x_emb, x_cont + + class MLP(nn.Module): def __init__( self, @@ -95,7 +168,7 @@ class TabMlp(nn.Module): ---------- column_idx: Dict Dict containing the index of the columns that will be passed through - the TabMlp model. Required to slice the tensors. e.g. {'education': + the ``TabMlp`` model. Required to slice the tensors. e.g. {'education': 0, 'relationship': 1, 'workclass': 2, ...} embed_input: List, Optional, default = None List of Tuples with the column name, number of unique values and @@ -111,7 +184,7 @@ class TabMlp(nn.Module): List with the number of neurons per dense layer in the mlp. mlp_activation: str, default = "relu" Activation function for the dense layers of the MLP. Currently - 'relu', 'leaky_relu' and 'gelu' are supported + ``tanh``, ``relu``, ``leaky_relu`` and ``gelu`` are supported mlp_dropout: float or List, default = 0.1 float or List of floats with the dropout between the dense layers. e.g: [0.5,0.5] @@ -128,13 +201,11 @@ class TabMlp(nn.Module): Attributes ---------- - cont_norm: ``nn.Module`` - continuous normalization layer + cat_embed_and_cont: ``nn.Module`` + This is the module that processes the categorical and continuous columns tab_mlp: ``nn.Sequential`` mlp model that will receive the concatenation of the embeddings and the continuous columns - embed_layers: ``nn.ModuleDict`` - ``ModuleDict`` with the embeddings set up output_dim: int The output dimension of the model. This is a required attribute neccesary to build the WideDeep class @@ -169,15 +240,15 @@ class TabMlp(nn.Module): super(TabMlp, self).__init__() self.column_idx = column_idx + self.embed_input = embed_input self.mlp_hidden_dims = mlp_hidden_dims + self.embed_dropout = embed_dropout + self.continuous_cols = continuous_cols + self.cont_norm_layer = cont_norm_layer self.mlp_activation = mlp_activation self.mlp_dropout = mlp_dropout self.mlp_batchnorm = mlp_batchnorm self.mlp_linear_first = mlp_linear_first - self.embed_input = embed_input - self.embed_dropout = embed_dropout - self.continuous_cols = continuous_cols - self.cont_norm_layer = cont_norm_layer if self.mlp_activation not in allowed_activations: raise ValueError( @@ -187,35 +258,17 @@ class TabMlp(nn.Module): ) ) - # Embeddings: val + 1 because 0 is reserved for padding/unseen cateogories. - if self.embed_input is not None: - self.embed_layers = nn.ModuleDict( - { - "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0) - for col, val, dim in self.embed_input - } - ) - self.embedding_dropout = nn.Dropout(embed_dropout) - emb_inp_dim = np.sum([embed[2] for embed in self.embed_input]) - else: - emb_inp_dim = 0 # type: ignore[assignment] - - # Continuous - if self.continuous_cols is not None: - self.cont_idx = [self.column_idx[col] for col in self.continuous_cols] - cont_inp_dim = len(self.continuous_cols) - if self.cont_norm_layer == "batchnorm": - self.cont_norm: NormLayers = nn.BatchNorm1d(cont_inp_dim) - elif self.cont_norm_layer == "layernorm": - self.cont_norm = nn.LayerNorm(cont_inp_dim) - else: - self.cont_norm = nn.Identity() - else: - cont_inp_dim = 0 + self.cat_embed_and_cont = CatEmbeddingsAndCont( + column_idx, + embed_input, + embed_dropout, + continuous_cols, + cont_norm_layer, + ) # MLP - input_dim = emb_inp_dim + cont_inp_dim - mlp_hidden_dims = [input_dim] + mlp_hidden_dims # type: ignore[assignment, operator] + mlp_input_dim = self.cat_embed_and_cont.output_dim + mlp_hidden_dims = [mlp_input_dim] + mlp_hidden_dims self.tab_mlp = MLP( mlp_hidden_dims, mlp_activation, @@ -228,18 +281,13 @@ class TabMlp(nn.Module): # the output_dim attribute will be used as input_dim when "merging" the models self.output_dim = mlp_hidden_dims[-1] - def forward(self, X: Tensor) -> Tensor: # type: ignore + def forward(self, X: Tensor) -> Tensor: r"""Forward pass that concatenates the continuous features with the embeddings. The result is then passed through a series of dense layers """ - if self.embed_input is not None: - embed = [ - self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long()) - for col, _, _ in self.embed_input - ] - x = torch.cat(embed, 1) - x = self.embedding_dropout(x) - if self.continuous_cols is not None: - x_cont = self.cont_norm((X[:, self.cont_idx].float())) - x = torch.cat([x, x_cont], 1) if self.embed_input is not None else x_cont + x_emb, x_cont = self.cat_embed_and_cont(X) + if x_emb is not None: + x = x_emb + if x_cont is not None: + x = torch.cat([x, x_cont], 1) if x_emb is not None else x_cont return self.tab_mlp(x) diff --git a/pytorch_widedeep/models/tab_resnet.py b/pytorch_widedeep/models/tab_resnet.py index c9ab3c83c40f9c74f74429aee684b2eaa9291e2a..e03ac4d3bc2b798bf2fbb394307d7aadc3b74f9f 100644 --- a/pytorch_widedeep/models/tab_resnet.py +++ b/pytorch_widedeep/models/tab_resnet.py @@ -1,15 +1,15 @@ from collections import OrderedDict -import numpy as np import torch from torch import nn from torch.nn import Module from pytorch_widedeep.wdtypes import * # noqa: F403 -from pytorch_widedeep.models.tab_mlp import MLP +from pytorch_widedeep.models.tab_mlp import MLP, CatEmbeddingsAndCont class BasicBlock(nn.Module): + # inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L37 def __init__(self, inp: int, out: int, dropout: float = 0.0, resize: Module = None): super(BasicBlock, self).__init__() @@ -86,20 +86,17 @@ class TabResnet(nn.Module): r"""Defines a so-called ``TabResnet`` model that can be used as the ``deeptabular`` component of a Wide & Deep model. - This class combines embedding representations of the categorical - features with numerical (aka continuous) features. These are then - passed through a series of Resnet blocks. See - ``pytorch_widedeep.models.tab_resnet.BasicBlock`` for details - on the structure of each block. - - .. note:: Unlike ``TabMlp``, ``TabResnet`` assumes that there are always - categorical columns + This class combines embedding representations of the categorical features + with numerical (aka continuous) features. These are then passed through a + series of Resnet blocks. See + :obj:`pytorch_widedeep.models.tab_resnet.BasicBlock` for details on the + structure of each block. Parameters ---------- column_idx: Dict Dict containing the index of the columns that will be passed through - the TabMlp model. Required to slice the tensors. e.g. {'education': + the ``Resnet`` model. Required to slice the tensors. e.g. {'education': 0, 'relationship': 1, 'workclass': 2, ...} embed_input: List List of Tuples with the column name, number of unique values and @@ -112,11 +109,11 @@ class TabResnet(nn.Module): Type of normalization layer applied to the continuous features. Options are: 'layernorm', 'batchnorm' or None. concat_cont_first: bool, default = True - Boolean indicating whether the continuum columns will be - concatenated with the Embeddings and then passed through the - Resnet blocks (``True``) or, alternatively, will be concatenated - with the result of passing the embeddings through the Resnet - Blocks (``False``) + If ``True`` the continuum columns will be concatenated with the + Categorical Embeddings and then passed through the Resnet blocks. If + ``False``, the Categorical Embeddings will be passed through the + Resnet blocks and then the output of the Resnet blocks will be + concatenated with the continuous features. blocks_dims: List, default = [200, 100, 100] List of integers that define the input and output units of each block. For example: [200, 100, 100] will generate 2 blocks. The first will @@ -128,12 +125,13 @@ class TabResnet(nn.Module): Block's `"internal"` dropout. This dropout is applied to the first of the two dense layers that comprise each ``BasicBlock``. mlp_hidden_dims: List, Optional, default = None - List with the number of neurons per dense layer in the mlp. e.g: + List with the number of neurons per dense layer in the MLP. e.g: [64, 32]. If ``None`` the output of the Resnet Blocks will be connected directly to the output neuron(s), i.e. using a MLP is optional. mlp_activation: str, default = "relu" - MLP activation function. 'relu', 'leaky_relu' and 'gelu' are supported + Activation function for the dense layers of the MLP. Currently + ``tanh``, ``relu``, ``leaky_relu`` and ``gelu`` are supported mlp_dropout: float, default = 0.1 float with the dropout between the dense layers of the MLP. mlp_batchnorm: bool, default = False @@ -149,13 +147,11 @@ class TabResnet(nn.Module): Attributes ---------- - embed_layers: ``nn.ModuleDict`` - ``ModuleDict`` with the embeddings setup + cat_embed_and_cont: ``nn.Module`` + This is the module that processes the categorical and continuous columns dense_resnet: ``nn.Sequential`` deep dense Resnet model that will receive the concatenation of the embeddings and the continuous columns - cont_norm: ``nn.Module`` - continuous normalization layer tab_resnet_mlp: ``nn.Sequential`` if ``mlp_hidden_dims`` is ``True``, this attribute will be an mlp model that will receive: @@ -188,7 +184,7 @@ class TabResnet(nn.Module): def __init__( self, column_idx: Dict[str, int], - embed_input: List[Tuple[str, int, int]], + embed_input: Optional[List[Tuple[str, int, int]]] = None, embed_dropout: float = 0.1, continuous_cols: Optional[List[str]] = None, cont_norm_layer: str = "batchnorm", @@ -204,6 +200,16 @@ class TabResnet(nn.Module): ): super(TabResnet, self).__init__() + if len(blocks_dims) < 2: + raise ValueError( + "'blocks' must contain at least two elements, e.g. [256, 128]" + ) + + if not concat_cont_first and embed_input is None: + raise ValueError( + "If 'concat_cont_first = False' 'embed_input' must be not 'None'" + ) + self.column_idx = column_idx self.embed_input = embed_input self.embed_dropout = embed_dropout @@ -218,43 +224,26 @@ class TabResnet(nn.Module): self.mlp_batchnorm_last = mlp_batchnorm_last self.mlp_linear_first = mlp_linear_first - if len(self.blocks_dims) < 2: - raise ValueError( - "'blocks' must contain at least two elements, e.g. [256, 128]" - ) - - # Embeddings: val + 1 because 0 is reserved for padding/unseen cateogories. - self.embed_layers = nn.ModuleDict( - { - "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0) - for col, val, dim in self.embed_input - } + self.cat_embed_and_cont = CatEmbeddingsAndCont( + column_idx, + embed_input, + embed_dropout, + continuous_cols, + cont_norm_layer, ) - self.embedding_dropout = nn.Dropout(embed_dropout) - emb_inp_dim = np.sum([embed[2] for embed in self.embed_input]) - - # Continuous - if self.continuous_cols is not None: - self.cont_idx = [self.column_idx[col] for col in self.continuous_cols] - cont_inp_dim = len(self.continuous_cols) - if self.cont_norm_layer == "batchnorm": - self.cont_norm: NormLayers = nn.BatchNorm1d(cont_inp_dim) - elif self.cont_norm_layer == "layernorm": - self.cont_norm = nn.LayerNorm(cont_inp_dim) - else: - self.cont_norm = nn.Identity() - else: - cont_inp_dim = 0 + + emb_out_dim = self.cat_embed_and_cont.emb_out_dim + cont_out_dim = self.cat_embed_and_cont.cont_out_dim # DenseResnet if self.concat_cont_first: - dense_resnet_input_dim = emb_inp_dim + cont_inp_dim + dense_resnet_input_dim = emb_out_dim + cont_out_dim self.output_dim = blocks_dims[-1] else: - dense_resnet_input_dim = emb_inp_dim - self.output_dim = cont_inp_dim + blocks_dims[-1] + dense_resnet_input_dim = emb_out_dim + self.output_dim = cont_out_dim + blocks_dims[-1] self.tab_resnet_blks = DenseResnet( - dense_resnet_input_dim, blocks_dims, blocks_dropout # type: ignore[arg-type] + dense_resnet_input_dim, blocks_dims, blocks_dropout ) # MLP @@ -262,7 +251,7 @@ class TabResnet(nn.Module): if self.concat_cont_first: mlp_input_dim = blocks_dims[-1] else: - mlp_input_dim = cont_inp_dim + blocks_dims[-1] + mlp_input_dim = cont_out_dim + blocks_dims[-1] mlp_hidden_dims = [mlp_input_dim] + mlp_hidden_dims self.tab_resnet_mlp = MLP( mlp_hidden_dims, @@ -274,26 +263,21 @@ class TabResnet(nn.Module): ) self.output_dim = mlp_hidden_dims[-1] - def forward(self, X: Tensor) -> Tensor: # type: ignore + def forward(self, X: Tensor) -> Tensor: r"""Forward pass that concatenates the continuous features with the embeddings. The result is then passed through a series of dense Resnet blocks""" - embed = [ - self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long()) - for col, _, _ in self.embed_input - ] - x = torch.cat(embed, 1) - x = self.embedding_dropout(x) - - if self.continuous_cols is not None: - x_cont = self.cont_norm((X[:, self.cont_idx].float())) + + x_emb, x_cont = self.cat_embed_and_cont(X) + + if x_cont is not None: if self.concat_cont_first: - x = torch.cat([x, x_cont], 1) + x = torch.cat([x_emb, x_cont], 1) if x_emb is not None else x_cont out = self.tab_resnet_blks(x) else: - out = torch.cat([self.tab_resnet_blks(x), x_cont], 1) + out = torch.cat([self.tab_resnet_blks(x_emb), x_cont], 1) else: - out = self.tab_resnet_blks(x) + out = self.tab_resnet_blks(x_emb) if self.mlp_hidden_dims is not None: out = self.tab_resnet_mlp(out) diff --git a/pytorch_widedeep/models/tabnet/_layers.py b/pytorch_widedeep/models/tabnet/_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..079f4e86549b7cb8eff989603c68796a2bfa9aec --- /dev/null +++ b/pytorch_widedeep/models/tabnet/_layers.py @@ -0,0 +1,358 @@ +""" +Most of the code here is a direct copy and paste from the fantastic tabnet +implementation here: https://github.com/dreamquark-ai/tabnet + +Therefore, ALL CREDIT TO THE DREAMQUARK-AI TEAM +----------------------------------------------- + +Here I simply adapted what I needed the TabNet to work within pytorch-widedeep +""" + +import warnings + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.models.tabnet import sparsemax + + +def initialize_non_glu(module, input_dim: int, output_dim: int): + gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(4 * input_dim)) + torch.nn.init.xavier_normal_(module.weight, gain=gain_value) + return + + +def initialize_glu(module, input_dim: int, output_dim: int): + gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(input_dim)) + torch.nn.init.xavier_normal_(module.weight, gain=gain_value) + return + + +class GBN(torch.nn.Module): + """ + Ghost Batch Normalization + https://arxiv.org/abs/1705.08741 + """ + + def __init__( + self, input_dim: int, virtual_batch_size: int = 128, momentum: float = 0.01 + ): + super(GBN, self).__init__() + self.virtual_batch_size = virtual_batch_size + self.bn = nn.BatchNorm1d(input_dim, momentum=momentum) + + def forward(self, X: Tensor) -> Tensor: + chunks = X.chunk(int(np.ceil(X.shape[0] / self.virtual_batch_size)), 0) + res = [self.bn(x_) for x_ in chunks] + return torch.cat(res, dim=0) + + +class GLU_Layer(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + dropout: float, + fc: nn.Module = None, + ghost_bn: bool = True, + virtual_batch_size: int = 128, + momentum: float = 0.02, + ): + super(GLU_Layer, self).__init__() + + if fc: + self.fc = fc + else: + self.fc = nn.Linear(input_dim, 2 * output_dim, bias=False) + initialize_glu(self.fc, input_dim, 2 * output_dim) + + if ghost_bn: + self.bn: Union[GBN, nn.BatchNorm1d] = GBN( + 2 * output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum + ) + else: + self.bn = nn.BatchNorm1d(2 * output_dim, momentum=momentum) + + self.dp = nn.Dropout(dropout) + + def forward(self, X: Tensor) -> Tensor: + return self.dp(F.glu(self.bn(self.fc(X)))) + + +class GLU_Block(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + dropout: float, + n_glu: int = 2, + first: bool = False, + shared_layers: nn.ModuleList = None, + ghost_bn: bool = True, + virtual_batch_size: int = 128, + momentum: float = 0.02, + ): + super(GLU_Block, self).__init__() + self.first = first + + if (shared_layers is not None) and (n_glu != len(shared_layers)): + self.n_glu = len(shared_layers) + warnings.warn( + "If 'shared_layers' is nor None, 'n_glu' must be equal to the number of shared_layers." + "Got n_glu = {} and n shared_layers = {}. 'n_glu' has been set to {}".format( + n_glu, len(shared_layers), len(shared_layers) + ), + UserWarning, + ) + else: + self.n_glu = n_glu + + glu_dim = [input_dim] + [output_dim] * self.n_glu + self.glu_layers = nn.ModuleList() + for i in range(self.n_glu): + fc = shared_layers[i] if shared_layers else None + self.glu_layers.append( + GLU_Layer( + glu_dim[i], + glu_dim[i + 1], + dropout, + fc=fc, + ghost_bn=ghost_bn, + virtual_batch_size=virtual_batch_size, + momentum=momentum, + ) + ) + + def forward(self, X: Tensor) -> Tensor: + scale = torch.sqrt(torch.FloatTensor([0.5]).to(X.device)) + + if self.first: # the first layer of the block has no scale multiplication + x = self.glu_layers[0](X) + layers_left = range(1, self.n_glu) + else: + x = nn.Identity()(X) + layers_left = range(self.n_glu) + + for glu_id in layers_left: + x = torch.add(x, self.glu_layers[glu_id](x)) * scale + + return x + + +class FeatTransformer(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + dropout: float, + shared_layers: nn.ModuleList, + n_glu_step_dependent: int, + ghost_bn=True, + virtual_batch_size=128, + momentum=0.02, + ): + super(FeatTransformer, self).__init__() + + params = { + "ghost_bn": ghost_bn, + "virtual_batch_size": virtual_batch_size, + "momentum": momentum, + } + + self.shared = GLU_Block( + input_dim, + output_dim, + dropout, + n_glu=len(shared_layers), + first=True, + shared_layers=shared_layers, + **params + ) + + self.step_dependent = GLU_Block( + output_dim, + output_dim, + dropout, + n_glu=n_glu_step_dependent, + first=False, + **params + ) + + def forward(self, X: Tensor) -> Tensor: + return self.step_dependent(self.shared(X)) + + +class AttentiveTransformer(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + mask_type: str = "sparsemax", + ghost_bn=True, + virtual_batch_size=128, + momentum=0.02, + ): + + super(AttentiveTransformer, self).__init__() + self.fc = nn.Linear(input_dim, output_dim, bias=False) + initialize_non_glu(self.fc, input_dim, output_dim) + if ghost_bn: + self.bn: Union[GBN, nn.BatchNorm1d] = GBN( + output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum + ) + else: + self.bn = nn.BatchNorm1d(output_dim, momentum=momentum) + + if mask_type == "sparsemax": + self.mask: Union[Sparsemax, Entmax15] = sparsemax.Sparsemax(dim=-1) + elif mask_type == "entmax": + self.mask = sparsemax.Entmax15(dim=-1) + else: + raise NotImplementedError( + "Please choose either 'sparsemax' or 'entmax' as masktype" + ) + + def forward(self, priors: Tensor, processed_feat: Tensor) -> Tensor: + x = self.bn(self.fc(processed_feat)) + x = torch.mul(x, priors) + return self.mask(x) + + +class TabNetEncoder(nn.Module): + def __init__( + self, + input_dim: int, + n_steps: int = 3, + step_dim: int = 8, + attn_dim: int = 8, + dropout: float = 0.0, + n_glu_step_dependent: int = 2, + n_glu_shared: int = 2, + ghost_bn: bool = True, + virtual_batch_size: int = 128, + momentum: float = 0.02, + gamma: float = 1.3, + epsilon: float = 1e-15, + mask_type: str = "sparsemax", + ): + super(TabNetEncoder, self).__init__() + + self.input_dim = input_dim + self.n_steps = n_steps + self.step_dim = step_dim + self.attn_dim = attn_dim + self.gamma = gamma + self.epsilon = epsilon + + self.initial_bn = nn.BatchNorm1d(input_dim, momentum=0.01) + + params = { + "ghost_bn": ghost_bn, + "virtual_batch_size": virtual_batch_size, + "momentum": momentum, + } + + shared_layers = nn.ModuleList() + for i in range(n_glu_shared): + if i == 0: + shared_layers.append( + nn.Linear(input_dim, 2 * (step_dim + attn_dim), bias=False) + ) + else: + shared_layers.append( + nn.Linear( + step_dim + attn_dim, 2 * (step_dim + attn_dim), bias=False + ) + ) + + self.initial_splitter = FeatTransformer( + input_dim, + step_dim + attn_dim, + dropout, + shared_layers, + n_glu_step_dependent, + **params + ) + + self.feat_transformers = nn.ModuleList() + self.attn_transformers = nn.ModuleList() + for step in range(n_steps): + feat_transformer = FeatTransformer( + input_dim, + step_dim + attn_dim, + dropout, + shared_layers, + n_glu_step_dependent, + **params + ) + attn_transformer = AttentiveTransformer( + attn_dim, input_dim, mask_type, **params + ) + self.feat_transformers.append(feat_transformer) + self.attn_transformers.append(attn_transformer) + + def forward(self, X: Tensor) -> Tuple[List[Tensor], Tensor]: + x = self.initial_bn(X) + + # P[n_step = 0] is initialized as all ones, 1^(B×D) + prior = torch.ones(x.shape).to(x.device) + + # sparsity regularization + M_loss = torch.FloatTensor([0.0]).to(x.device) + + # split block + attn = self.initial_splitter(x)[:, self.step_dim :] + + steps_output = [] + for step in range(self.n_steps): + # learnable mask: M[i] = sparsemax(prior[i − 1] · hi(a[i − 1])) + # where hi = FC + BN + M = self.attn_transformers[step](prior, attn) + + # update prior: P[i] = \prod_{i}^{j=1} (γ − M[j]) + prior = torch.mul(self.gamma - M, prior) + + # sparsity regularization + M_loss += torch.mean( + torch.sum(torch.mul(M, torch.log(M + self.epsilon)), dim=1) + ) + + # update attention and d_out + masked_x = torch.mul(M, x) + out = self.feat_transformers[step](masked_x) + attn = out[:, self.step_dim :] + d_out = nn.ReLU()(out[:, : self.step_dim]) + steps_output.append(d_out) + + M_loss /= self.n_steps # type: ignore[has-type] + + return steps_output, M_loss + + def forward_masks(self, X: Tensor) -> Tuple[Tensor, Dict[int, Tensor]]: + x = self.initial_bn(X) + + prior = torch.ones(x.shape).to(x.device) + M_explain = torch.zeros(x.shape).to(x.device) + attn = self.initial_splitter(x)[:, self.step_dim :] + masks = {} + + for step in range(self.n_steps): + M = self.attn_transformers[step](prior, attn) + masks[step] = M + + prior = torch.mul(self.gamma - M, prior) + + masked_x = torch.mul(M, x) + out = self.feat_transformers[step](masked_x) + attn = out[:, self.step_dim :] + # 'decision contribution' in the paper + d_out = nn.ReLU()(out[:, : self.step_dim]) + + # aggregate decision contribution + agg_decision_contrib = torch.sum(d_out, dim=1) + M_explain += torch.mul(M, agg_decision_contrib.unsqueeze(dim=1)) + + return M_explain, masks diff --git a/pytorch_widedeep/models/tabnet/tab_net_utils.py b/pytorch_widedeep/models/tabnet/_utils.py similarity index 96% rename from pytorch_widedeep/models/tabnet/tab_net_utils.py rename to pytorch_widedeep/models/tabnet/_utils.py index 6754ba3020b776e36d418c2620eada0d3804e047..79f62488f5ebddfd5d81afa5049c79d394979d69 100644 --- a/pytorch_widedeep/models/tabnet/tab_net_utils.py +++ b/pytorch_widedeep/models/tabnet/_utils.py @@ -17,7 +17,7 @@ def create_explain_matrix(model: WideDeep) -> csc_matrix: Examples -------- >>> from pytorch_widedeep.models import TabNet, WideDeep - >>> from pytorch_widedeep.models.tabnet.tab_net_utils import create_explain_matrix + >>> from pytorch_widedeep.models.tabnet._utils import create_explain_matrix >>> embed_input = [("a", 4, 2), ("b", 4, 2), ("c", 4, 2)] >>> cont_cols = ["d", "e"] >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c", "d", "e"])} diff --git a/pytorch_widedeep/models/tabnet/tab_net.py b/pytorch_widedeep/models/tabnet/tab_net.py index 4e128314bd77fc112b4f6212baee75acb41b7131..31ecc547020a2471dc3e0a8ab1543f72a1cb6874 100644 --- a/pytorch_widedeep/models/tabnet/tab_net.py +++ b/pytorch_widedeep/models/tabnet/tab_net.py @@ -1,420 +1,18 @@ -""" -Most of the code here is a direct copy and paste from the fantastic tabnet -implementation here: https://github.com/dreamquark-ai/tabnet - -Therefore, ALL CREDIT TO THE DREAMQUARK-AI TEAM ------------------------------------------------ - -Here I simply adapted what I needed the TabNet to work within pytorch-widedeep -""" - -import warnings - -import numpy as np import torch -import torch.nn.functional as F from torch import nn from pytorch_widedeep.wdtypes import * # noqa: F403 -from pytorch_widedeep.models.tabnet import sparsemax - - -def initialize_non_glu(module, input_dim: int, output_dim: int): - gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(4 * input_dim)) - torch.nn.init.xavier_normal_(module.weight, gain=gain_value) - return - - -def initialize_glu(module, input_dim: int, output_dim: int): - gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(input_dim)) - torch.nn.init.xavier_normal_(module.weight, gain=gain_value) - return - - -class GBN(torch.nn.Module): - """ - Ghost Batch Normalization - https://arxiv.org/abs/1705.08741 - """ - - def __init__( - self, input_dim: int, virtual_batch_size: int = 128, momentum: float = 0.01 - ): - super(GBN, self).__init__() - self.virtual_batch_size = virtual_batch_size - self.bn = nn.BatchNorm1d(input_dim, momentum=momentum) - - def forward(self, X: Tensor) -> Tensor: - chunks = X.chunk(int(np.ceil(X.shape[0] / self.virtual_batch_size)), 0) - res = [self.bn(x_) for x_ in chunks] - return torch.cat(res, dim=0) - - -class GLU_Layer(nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - dropout: float, - fc: nn.Module = None, - ghost_bn: bool = True, - virtual_batch_size: int = 128, - momentum: float = 0.02, - ): - super(GLU_Layer, self).__init__() - - if fc: - self.fc = fc - else: - self.fc = nn.Linear(input_dim, 2 * output_dim, bias=False) - initialize_glu(self.fc, input_dim, 2 * output_dim) - - if ghost_bn: - self.bn: Union[GBN, nn.BatchNorm1d] = GBN( - 2 * output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum - ) - else: - self.bn = nn.BatchNorm1d(2 * output_dim, momentum=momentum) - - self.dp = nn.Dropout(dropout) - - def forward(self, X: Tensor) -> Tensor: - return self.dp(F.glu(self.bn(self.fc(X)))) - - -class GLU_Block(nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - dropout: float, - n_glu: int = 2, - first: bool = False, - shared_layers: nn.ModuleList = None, - ghost_bn: bool = True, - virtual_batch_size: int = 128, - momentum: float = 0.02, - ): - super(GLU_Block, self).__init__() - self.first = first - - if (shared_layers is not None) and (n_glu != len(shared_layers)): - self.n_glu = len(shared_layers) - warnings.warn( - "If 'shared_layers' is nor None, 'n_glu' must be equal to the number of shared_layers." - "Got n_glu = {} and n shared_layers = {}. 'n_glu' has been set to {}".format( - n_glu, len(shared_layers), len(shared_layers) - ), - UserWarning, - ) - else: - self.n_glu = n_glu - - glu_dim = [input_dim] + [output_dim] * self.n_glu - self.glu_layers = nn.ModuleList() - for i in range(self.n_glu): - fc = shared_layers[i] if shared_layers else None - self.glu_layers.append( - GLU_Layer( - glu_dim[i], - glu_dim[i + 1], - dropout, - fc=fc, - ghost_bn=ghost_bn, - virtual_batch_size=virtual_batch_size, - momentum=momentum, - ) - ) - - def forward(self, X: Tensor) -> Tensor: - scale = torch.sqrt(torch.FloatTensor([0.5]).to(X.device)) - - if self.first: # the first layer of the block has no scale multiplication - x = self.glu_layers[0](X) - layers_left = range(1, self.n_glu) - else: - x = nn.Identity()(X) - layers_left = range(self.n_glu) - - for glu_id in layers_left: - x = torch.add(x, self.glu_layers[glu_id](x)) * scale - - return x - - -class FeatTransformer(nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - dropout: float, - shared_layers: nn.ModuleList, - n_glu_step_dependent: int, - ghost_bn=True, - virtual_batch_size=128, - momentum=0.02, - ): - super(FeatTransformer, self).__init__() - - params = { - "ghost_bn": ghost_bn, - "virtual_batch_size": virtual_batch_size, - "momentum": momentum, - } - - self.shared = GLU_Block( - input_dim, - output_dim, - dropout, - n_glu=len(shared_layers), - first=True, - shared_layers=shared_layers, - **params - ) - - self.step_dependent = GLU_Block( - output_dim, - output_dim, - dropout, - n_glu=n_glu_step_dependent, - first=False, - **params - ) - - def forward(self, X: Tensor) -> Tensor: - return self.step_dependent(self.shared(X)) - - -class AttentiveTransformer(nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - mask_type: str = "sparsemax", - ghost_bn=True, - virtual_batch_size=128, - momentum=0.02, - ): - - super(AttentiveTransformer, self).__init__() - self.fc = nn.Linear(input_dim, output_dim, bias=False) - initialize_non_glu(self.fc, input_dim, output_dim) - if ghost_bn: - self.bn: Union[GBN, nn.BatchNorm1d] = GBN( - output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum - ) - else: - self.bn = nn.BatchNorm1d(output_dim, momentum=momentum) - - if mask_type == "sparsemax": - self.mask: Union[Sparsemax, Entmax15] = sparsemax.Sparsemax(dim=-1) - elif mask_type == "entmax": - self.mask = sparsemax.Entmax15(dim=-1) - else: - raise NotImplementedError( - "Please choose either 'sparsemax' or 'entmax' as masktype" - ) - - def forward(self, priors: Tensor, processed_feat: Tensor) -> Tensor: - x = self.bn(self.fc(processed_feat)) - x = torch.mul(x, priors) - return self.mask(x) - - -class TabNetEncoder(nn.Module): - def __init__( - self, - input_dim: int, - n_steps: int = 3, - step_dim: int = 8, - attn_dim: int = 8, - dropout: float = 0.0, - n_glu_step_dependent: int = 2, - n_glu_shared: int = 2, - ghost_bn: bool = True, - virtual_batch_size: int = 128, - momentum: float = 0.02, - gamma: float = 1.3, - epsilon: float = 1e-15, - mask_type: str = "sparsemax", - ): - super(TabNetEncoder, self).__init__() - - self.input_dim = input_dim - self.n_steps = n_steps - self.step_dim = step_dim - self.attn_dim = attn_dim - self.gamma = gamma - self.epsilon = epsilon - - self.initial_bn = nn.BatchNorm1d(input_dim, momentum=0.01) - - params = { - "ghost_bn": ghost_bn, - "virtual_batch_size": virtual_batch_size, - "momentum": momentum, - } - - shared_layers = nn.ModuleList() - for i in range(n_glu_shared): - if i == 0: - shared_layers.append( - nn.Linear(input_dim, 2 * (step_dim + attn_dim), bias=False) - ) - else: - shared_layers.append( - nn.Linear( - step_dim + attn_dim, 2 * (step_dim + attn_dim), bias=False - ) - ) - - self.initial_splitter = FeatTransformer( - input_dim, - step_dim + attn_dim, - dropout, - shared_layers, - n_glu_step_dependent, - **params - ) - - self.feat_transformers = nn.ModuleList() - self.attn_transformers = nn.ModuleList() - for step in range(n_steps): - feat_transformer = FeatTransformer( - input_dim, - step_dim + attn_dim, - dropout, - shared_layers, - n_glu_step_dependent, - **params - ) - attn_transformer = AttentiveTransformer( - attn_dim, input_dim, mask_type, **params - ) - self.feat_transformers.append(feat_transformer) - self.attn_transformers.append(attn_transformer) - - def forward(self, X: Tensor) -> Tuple[List[Tensor], Tensor]: - x = self.initial_bn(X) - - # P[n_step = 0] is initialized as all ones, 1^(B×D) - prior = torch.ones(x.shape).to(x.device) - - # sparsity regularization - M_loss = torch.FloatTensor([0.0]).to(x.device) - - # split block - attn = self.initial_splitter(x)[:, self.step_dim :] - - steps_output = [] - for step in range(self.n_steps): - # learnable mask: M[i] = sparsemax(prior[i − 1] · hi(a[i − 1])) - # where hi = FC + BN - M = self.attn_transformers[step](prior, attn) - - # update prior: P[i] = \prod_{i}^{j=1} (γ − M[j]) - prior = torch.mul(self.gamma - M, prior) - - # sparsity regularization - M_loss += torch.mean( - torch.sum(torch.mul(M, torch.log(M + self.epsilon)), dim=1) - ) - - # update attention and d_out - masked_x = torch.mul(M, x) - out = self.feat_transformers[step](masked_x) - attn = out[:, self.step_dim :] - d_out = nn.ReLU()(out[:, : self.step_dim]) - steps_output.append(d_out) - - M_loss /= self.n_steps # type: ignore[has-type] - - return steps_output, M_loss - - def forward_masks(self, X: Tensor) -> Tuple[Tensor, Dict[int, Tensor]]: - x = self.initial_bn(X) - - prior = torch.ones(x.shape).to(x.device) - M_explain = torch.zeros(x.shape).to(x.device) - attn = self.initial_splitter(x)[:, self.step_dim :] - masks = {} - - for step in range(self.n_steps): - M = self.attn_transformers[step](prior, attn) - masks[step] = M - - prior = torch.mul(self.gamma - M, prior) - - masked_x = torch.mul(M, x) - out = self.feat_transformers[step](masked_x) - attn = out[:, self.step_dim :] - # 'decision contribution' in the paper - d_out = nn.ReLU()(out[:, : self.step_dim]) - - # aggregate decision contribution - agg_decision_contrib = torch.sum(d_out, dim=1) - M_explain += torch.mul(M, agg_decision_contrib.unsqueeze(dim=1)) - - return M_explain, masks - - -class EmbeddingsAndContinuous(nn.Module): - def __init__( - self, - column_idx: Dict[str, int], - embed_input: List[Tuple[str, int, int]], - embed_dropout: float, - continuous_cols: Optional[List[str]], - cont_norm_layer: str, - ): - super(EmbeddingsAndContinuous, self).__init__() - - self.column_idx = column_idx - self.embed_input = embed_input - self.continuous_cols = continuous_cols - self.cont_norm_layer = cont_norm_layer - - # Embeddings: val + 1 because 0 is reserved for padding/unseen cateogories. - self.embed_layers = nn.ModuleDict( - { - "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0) - for col, val, dim in self.embed_input - } - ) - self.embedding_dropout = nn.Dropout(embed_dropout) - emb_out_dim = np.sum([embed[2] for embed in self.embed_input]) - - # Continuous - if self.continuous_cols is not None: - self.cont_idx = [self.column_idx[col] for col in self.continuous_cols] - cont_out_dim = len(self.continuous_cols) - if self.cont_norm_layer == "batchnorm": - self.cont_norm: NormLayers = nn.BatchNorm1d(cont_out_dim) - if self.cont_norm_layer == "layernorm": - self.cont_norm = nn.LayerNorm(cont_out_dim) - else: - self.cont_norm = nn.Identity() - else: - cont_out_dim = 0 - - self.output_dim: int = emb_out_dim + cont_out_dim # type: ignore[assignment] - - def forward(self, X: Tensor) -> Tensor: - embed = [ - self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long()) - for col, _, _ in self.embed_input - ] - x = torch.cat(embed, 1) - x = self.embedding_dropout(x) - if self.continuous_cols is not None: - x_cont = self.cont_norm((X[:, self.cont_idx].float())) - x = torch.cat([x, x_cont], 1) if self.embed_input is not None else x_cont - return x +from pytorch_widedeep.models.tab_mlp import CatEmbeddingsAndCont +from pytorch_widedeep.models.tabnet._layers import ( + TabNetEncoder, + initialize_non_glu, +) class TabNet(nn.Module): - r"""TabNet model (https://arxiv.org/abs/1908.07442) model that can be used - as the deeptabular component of a Wide & Deep model. + r"""Defines a ``TabNet`` model (https://arxiv.org/abs/1908.07442) model + that can be used as the ``deeptabular`` component of a Wide & Deep + model. The implementation in this library is fully based on that here: https://github.com/dreamquark-ai/tabnet, simply adapted so that it can @@ -448,7 +46,7 @@ class TabNet(nn.Module): attn_dim: int, default = 8 Attention dimension dropout: float, default = 0.0 - GLU block 'internal' dropout + GLU block's internal dropout n_glu_step_dependent: int, default = 2 number of GLU Blocks [FC -> BN -> GLU] that are step dependent n_glu_shared: int, default = 2 @@ -461,13 +59,13 @@ class TabNet(nn.Module): Batch size when using Ghost Batch Normalization momentum: float, default = 0.02 Ghost Batch Normalization's momentum. The dreamquark-ai advises for - very low values. The results in the paper use significantly higher - values. Higher values lead to better results in my experimentations + very low values. However high values are used in the original + publication. During our tests higher values lead to better results gamma: float, default = 1.3 Relaxation parameter in the paper. When gamma = 1, a feature is - enforced to be used only at one decision step and as gamma - increases, more flexibility is provided to use a feature at - multiple decision steps + enforced to be used only at one decision step. As gamma increases, + more flexibility is provided to use a feature at multiple decision + steps epsilon: float, default = 1e-15 Float to avoid log(0). Always keep low mask_type: str, default = "sparsemax" @@ -475,16 +73,14 @@ class TabNet(nn.Module): Attributes ---------- - embed_and_cont: ``nn.ModuleDict`` - ``ModuleDict`` with the embeddings and continuous setup - embed_and_cont_dim: int - embeddings plus continuous dimension - output_dim: int - The output dimension of the model. This is a required attribute - neccesary to build the WideDeep class + cat_embed_and_cont: ``nn.Module`` + This is the module that processes the categorical and continuous columns tabnet_encoder: ``nn.Module`` ``Module`` containing the TabNet encoder. See the `paper `_. + output_dim: int + The output dimension of the model. This is a required attribute + neccesary to build the WideDeep class Example -------- @@ -537,12 +133,17 @@ class TabNet(nn.Module): self.epsilon = epsilon self.mask_type = mask_type - self.embed_and_cont = EmbeddingsAndContinuous( - column_idx, embed_input, embed_dropout, continuous_cols, cont_norm_layer + self.cat_embed_and_cont = CatEmbeddingsAndCont( + column_idx, + embed_input, + embed_dropout, + continuous_cols, + cont_norm_layer, ) - self.embed_and_cont_dim = self.embed_and_cont.output_dim + + self.embed_and_cont_dim = self.cat_embed_and_cont.output_dim self.tabnet_encoder = TabNetEncoder( - self.embed_and_cont.output_dim, + self.embed_and_cont_dim, n_steps, step_dim, attn_dim, @@ -559,13 +160,26 @@ class TabNet(nn.Module): self.output_dim = step_dim def forward(self, X: Tensor) -> Tuple[Tensor, Tensor]: - x = self.embed_and_cont(X) + + x_emb, x_cont = self.cat_embed_and_cont(X) + if x_emb is not None: + x = x_emb + if x_cont is not None: + x = torch.cat([x, x_cont], 1) if x_emb is not None else x_cont + steps_output, M_loss = self.tabnet_encoder(x) res = torch.sum(torch.stack(steps_output, dim=0), dim=0) + return (res, M_loss) def forward_masks(self, X: Tensor) -> Tuple[Tensor, Dict[int, Tensor]]: - x = self.embed_and_cont(X) + + x_emb, x_cont = self.cat_embed_and_cont(X) + if x_emb is not None: + x = x_emb + if x_cont is not None: + x = torch.cat([x, x_cont], 1) if x_emb is not None else x_cont + return self.tabnet_encoder.forward_masks(x) diff --git a/pytorch_widedeep/models/transformers/_attention_layers.py b/pytorch_widedeep/models/transformers/_attention_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb24184517155f833e76e4302dc7a072e0406d5 --- /dev/null +++ b/pytorch_widedeep/models/transformers/_attention_layers.py @@ -0,0 +1,243 @@ +""" +MultiHeadedAttention is inspired by the implementation at +https://github.com/lucidrains +""" + +import math + +import torch +import einops +from torch import nn, einsum + +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.models.tab_mlp import get_activation_fn + + +class PositionwiseFF(nn.Module): + def __init__( + self, + input_dim: int, + dropout: float, + activation: str, + mult: float = 4.0, + ): + super(PositionwiseFF, self).__init__() + ff_hidden_dim = int(input_dim * mult) + self.w_1 = nn.Linear( + input_dim, + ff_hidden_dim * 2 if activation.endswith("glu") else ff_hidden_dim, + ) + self.w_2 = nn.Linear(ff_hidden_dim, input_dim) + self.dropout = nn.Dropout(dropout) + self.activation = get_activation_fn(activation) + + def forward(self, X: Tensor) -> Tensor: + return self.w_2(self.dropout(self.activation(self.w_1(X)))) + + +class NormAdd(nn.Module): + """aka PreNorm""" + + def __init__(self, input_dim: int, dropout: float): + super(NormAdd, self).__init__() + self.dropout = nn.Dropout(dropout) + self.ln = nn.LayerNorm(input_dim) + + def forward(self, X: Tensor, sublayer: nn.Module) -> Tensor: + return X + self.dropout(sublayer(self.ln(X))) + + +class AddNorm(nn.Module): + """aka PosNorm""" + + def __init__(self, input_dim: int, dropout: float): + super(AddNorm, self).__init__() + self.dropout = nn.Dropout(dropout) + self.ln = nn.LayerNorm(input_dim) + + def forward(self, X: Tensor, sublayer: nn.Module) -> Tensor: + return self.ln(X + self.dropout(sublayer(X))) + + +class MultiHeadedAttention(nn.Module): + def __init__( + self, + input_dim: int, + n_heads: int, + use_bias: bool, + dropout: float, + query_dim: Optional[int] = None, + ): + super(MultiHeadedAttention, self).__init__() + + assert input_dim % n_heads == 0, "'input_dim' must be divisible by 'n_heads'" + + self.head_dim = input_dim // n_heads + self.n_heads = n_heads + + self.dropout = nn.Dropout(dropout) + + query_dim = query_dim if query_dim is not None else input_dim + self.q_proj = nn.Linear(query_dim, input_dim, bias=use_bias) + self.kv_proj = nn.Linear(input_dim, input_dim * 2, bias=use_bias) + self.out_proj = ( + nn.Linear(input_dim, query_dim, bias=use_bias) if n_heads > 1 else None + ) + + def forward(self, X_Q: Tensor, X_KV: Optional[Tensor] = None) -> Tensor: + # b: batch size + # s: seq length + # l: target sequence length + # m: used to refer indistinctively to s or l + # h: number of attention heads, + # d: head_dim + q = self.q_proj(X_Q) + X_KV = X_KV if X_KV is not None else X_Q + k, v = self.kv_proj(X_KV).chunk(2, dim=-1) + q, k, v = map( + lambda t: einops.rearrange(t, "b m (h d) -> b h m d", h=self.n_heads), + (q, k, v), + ) + scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(self.head_dim) + attn_weights = scores.softmax(dim=-1) + self.attn_weights = attn_weights + attn_weights = self.dropout(attn_weights) + attn_output = einsum("b h s l, b h l d -> b h s d", attn_weights, v) + output = einops.rearrange(attn_output, "b h s d -> b s (h d)", h=self.n_heads) + + if self.out_proj is not None: + output = self.out_proj(output) + + return output + + +class LinearAttention(nn.Module): + def __init__( + self, + input_dim: int, + n_feats: int, + n_heads: int, + use_bias: bool, + dropout: float, + kv_compression_factor: float, + kv_sharing: bool, + ): + super(LinearAttention, self).__init__() + assert input_dim % n_heads == 0, "'input_dim' must be divisible by 'n_heads'" + + self.n_feats = n_feats + self.head_dim = input_dim // n_heads + self.n_heads = n_heads + self.kv_compression_factor = kv_compression_factor + self.share_kv = kv_sharing + + dim_k = int(self.kv_compression_factor * self.n_feats) + + self.dropout = nn.Dropout(dropout) + self.qkv_proj = nn.Linear(input_dim, input_dim * 3, bias=use_bias) + + self.E = nn.init.xavier_uniform_(nn.Parameter(torch.zeros(n_feats, dim_k))) + if not kv_sharing: + self.F = nn.init.xavier_uniform_(nn.Parameter(torch.zeros(n_feats, dim_k))) + else: + self.F = self.E + + self.out_proj = ( + nn.Linear(input_dim, input_dim, bias=use_bias) if n_heads > 1 else None + ) + + def forward(self, X: Tensor) -> Tensor: + # b: batch size + # s: seq length + # h: number of attention heads, + # i: input dim + # k: k dim + # d: head dim + q, k, v = self.qkv_proj(X).chunk(3, dim=-1) + + q = einops.rearrange(q, "b s (h d) -> b h s d", h=self.n_heads) + k = einsum("b s i, s k -> b k i", k, self.E) + v = einsum("b s i, s k -> b k i", v, self.F) + + k = einops.rearrange(k, "b k (h d) -> b h k d", d=self.head_dim) + v = einops.rearrange(v, "b k (h d) -> b h k d", d=self.head_dim) + + scores = einsum("b h s d, b h k d -> b h s k", q, k) / math.sqrt(self.head_dim) + attn_weights = scores.softmax(dim=-1) + self.attn_weights = attn_weights + attn_weights = self.dropout(attn_weights) + output = einsum("b h s k, b h k d -> b h s d", attn_weights, v) + output = einops.rearrange(output, "b h s d -> b s (h d)") + + if self.out_proj is not None: + output = self.out_proj(output) + + return output + + +class AdditiveAttention(nn.Module): + def __init__( + self, + input_dim: int, + n_heads: int, + use_bias: bool, + dropout: float, + share_qv_weights: bool, + ): + super(AdditiveAttention, self).__init__() + + assert input_dim % n_heads == 0, "'input_dim' must be divisible by 'n_heads'" + + self.head_dim = input_dim // n_heads + self.n_heads = n_heads + self.share_qv_weights = share_qv_weights + + self.dropout = nn.Dropout(dropout) + + # In the paper: " [...] we share the value and query transformation + # parameters to reduce the memory cost [...]" + if share_qv_weights: + self.qv_proj = nn.Linear(input_dim, input_dim, bias=use_bias) + else: + self.q_proj = nn.Linear(input_dim, input_dim, bias=use_bias) + self.v_proj = nn.Linear(input_dim, input_dim, bias=use_bias) + self.k_proj = nn.Linear(input_dim, input_dim, bias=use_bias) + + self.W_q = nn.Linear(input_dim, n_heads) + self.W_k = nn.Linear(input_dim, n_heads) + + self.r_out = nn.Linear(input_dim, input_dim) + + def forward(self, X: Tensor) -> Tensor: + # b: batch size + # s: seq length + # h: number of attention heads, + # d: head_dim + q = self.qv_proj(X) if self.share_qv_weights else self.q_proj(X) + v = self.qv_proj(X) if self.share_qv_weights else self.v_proj(X) + k = self.k_proj(X) + + alphas = (self.W_q(q) / math.sqrt(self.head_dim)).softmax(dim=-1) + q_r = einops.rearrange(q, "b s (h d) -> b s h d", h=self.n_heads) + global_query = einsum(" b s h, b s h d -> b h d", alphas, q_r) + global_query = einops.rearrange(global_query, "b h d -> b () (h d)") + + p = k * global_query + + betas = (self.W_k(p) / math.sqrt(self.head_dim)).softmax(dim=-1) + p_r = einops.rearrange(p, "b s (h d) -> b s h d", h=self.n_heads) + global_key = einsum(" b s h, b s h d -> b h d", betas, p_r) + global_key = einops.rearrange(global_key, "b h d -> b () (h d)") + + u = v * global_key + + # for consistency with all other transformer-based models, rearrange + # the attn_weights + self.attn_weights = ( + einops.rearrange(alphas, "b s h -> b h s"), + einops.rearrange(betas, "b s h -> b h s"), + ) + + output = q + self.dropout(self.r_out(u)) + + return output diff --git a/pytorch_widedeep/models/transformers/_embeddings_layers.py b/pytorch_widedeep/models/transformers/_embeddings_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..2ce7da83e0323bb095cd453be9b15950925097b9 --- /dev/null +++ b/pytorch_widedeep/models/transformers/_embeddings_layers.py @@ -0,0 +1,246 @@ +""" +SharedEmbeddings is inspired by the TabTransformer available in AutoGluon: +https://github.com/awslabs/autogluon/tree/master/tabular/src/autogluon/tabular/models/tab_transformer +""" + +import math + +import torch +from torch import nn + +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.models.tab_mlp import get_activation_fn + + +class FullEmbeddingDropout(nn.Module): + def __init__(self, dropout: float): + super(FullEmbeddingDropout, self).__init__() + self.dropout = dropout + + def forward(self, X: Tensor) -> Tensor: + mask = X.new().resize_((X.size(1), 1)).bernoulli_(1 - self.dropout).expand_as( + X + ) / (1 - self.dropout) + return mask * X + + +DropoutLayers = Union[nn.Dropout, FullEmbeddingDropout] + + +class SharedEmbeddings(nn.Module): + def __init__( + self, + n_embed: int, + embed_dim: int, + embed_dropout: float, + full_embed_dropout: bool = False, + add_shared_embed: bool = False, + frac_shared_embed=0.25, + ): + super(SharedEmbeddings, self).__init__() + + assert frac_shared_embed < 1, "'frac_shared_embed' must be less than 1" + self.add_shared_embed = add_shared_embed + self.embed = nn.Embedding(n_embed, embed_dim, padding_idx=0) + self.embed.weight.data.clamp_(-2, 2) + if add_shared_embed: + col_embed_dim = embed_dim + else: + col_embed_dim = int(embed_dim * frac_shared_embed) + self.shared_embed = nn.Parameter(torch.empty(1, col_embed_dim).uniform_(-1, 1)) + + if full_embed_dropout: + self.dropout: DropoutLayers = FullEmbeddingDropout(embed_dropout) + else: + self.dropout = nn.Dropout(embed_dropout) + + def forward(self, X: Tensor) -> Tensor: + out = self.dropout(self.embed(X)) + shared_embed = self.shared_embed.expand(out.shape[0], -1) + if self.add_shared_embed: + out += shared_embed + else: + out[:, : shared_embed.shape[1]] = shared_embed + return out + + +class ContinuousEmbeddings(nn.Module): + def __init__( + self, + n_cont_cols: int, + embed_dim: int, + use_bias: bool, + activation: str = None, + ): + super(ContinuousEmbeddings, self).__init__() + + self.weight = nn.Parameter(torch.Tensor(n_cont_cols, embed_dim)) + + self.bias = ( + nn.Parameter(torch.Tensor(n_cont_cols, embed_dim)) if use_bias else None + ) + self._reset_parameters() + + self.act_fn = get_activation_fn(activation) if activation else None + + def _reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, X: Tensor) -> Tensor: + x = self.weight.unsqueeze(0) * X.unsqueeze(2) + + if self.bias is not None: + x = x + self.bias.unsqueeze(0) + + if self.act_fn is not None: + x = self.act_fn(x) + + return x + + +class CategoricalEmbeddings(nn.Module): + def __init__( + self, + embed_dim: int, + column_idx: Dict[str, int], + embed_input: Optional[List[Tuple[str, int]]], + embed_dropout: float, + full_embed_dropout: bool, + shared_embed: bool, + add_shared_embed: bool, + frac_shared_embed: float, + use_bias: bool, + ): + super(CategoricalEmbeddings, self).__init__() + self.column_idx = column_idx + self.embed_input = embed_input + self.embed_dropout = embed_dropout + self.shared_embed = shared_embed + + self.n_tokens = sum([ei[1] for ei in embed_input]) + self.categorical_cols = [ei[0] for ei in embed_input] + self.cat_idx = [self.column_idx[col] for col in self.categorical_cols] + + self.bias = ( + nn.Parameter(torch.Tensor(len(self.categorical_cols), embed_dim)) + if use_bias + else None + ) + if self.bias is not None: + nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5)) + + # Categorical: val + 1 because 0 is reserved for padding/unseen cateogories. + if self.shared_embed: + self.embed: Union[nn.ModuleDict, nn.Embedding] = nn.ModuleDict( + { + "emb_layer_" + + col: SharedEmbeddings( + val if col == "cls_token" else val + 1, + embed_dim, + embed_dropout, + full_embed_dropout, + add_shared_embed, + frac_shared_embed, + ) + for col, val in self.embed_input + } + ) + else: + self.embed = nn.Embedding(self.n_tokens + 1, embed_dim, padding_idx=0) + if full_embed_dropout: + self.dropout: DropoutLayers = FullEmbeddingDropout(embed_dropout) + else: + self.dropout = nn.Dropout(embed_dropout) + + def forward(self, X: Tensor) -> Tensor: + if self.shared_embed: + cat_embed = [ + self.embed["emb_layer_" + col]( # type: ignore[index] + X[:, self.column_idx[col]].long() + ).unsqueeze(1) + for col, _ in self.embed_input + ] + x = torch.cat(cat_embed, 1) + else: + x = self.embed(X[:, self.cat_idx].long()) + x = self.dropout(x) + + if self.bias is not None: + x = x + self.bias.unsqueeze(0) + + return x + + +class CatAndContEmbeddings(nn.Module): + def __init__( + self, + embed_dim: int, + column_idx: Dict[str, int], + embed_input: Optional[List[Tuple[str, int]]], + embed_dropout: float, + full_embed_dropout: bool, + shared_embed: bool, + add_shared_embed: bool, + frac_shared_embed: float, + use_embed_bias: bool, + continuous_cols: Optional[List[str]], + embed_continuous: bool, + embed_continuous_activation: str, + use_cont_bias: bool, + cont_norm_layer: str, + ): + super(CatAndContEmbeddings, self).__init__() + + self.embed_input = embed_input + self.continuous_cols = continuous_cols + self.embed_continuous = embed_continuous + + # Categorical + if embed_input is not None: + self.cat_embed = CategoricalEmbeddings( + embed_dim, + column_idx, + embed_input, + embed_dropout, + full_embed_dropout, + shared_embed, + add_shared_embed, + frac_shared_embed, + use_embed_bias, + ) + # Continuous + if continuous_cols is not None: + self.cont_idx = [column_idx[col] for col in continuous_cols] + if cont_norm_layer == "layernorm": + self.cont_norm: NormLayers = nn.LayerNorm(len(continuous_cols)) + elif cont_norm_layer == "batchnorm": + self.cont_norm = nn.BatchNorm1d(len(continuous_cols)) + else: + self.cont_norm = nn.Identity() + if self.embed_continuous: + self.cont_embed = ContinuousEmbeddings( + len(continuous_cols), + embed_dim, + use_cont_bias, + embed_continuous_activation, + ) + + def forward(self, X: Tensor) -> Tuple[Tensor, Any]: + + if self.embed_input is not None: + x_cat = self.cat_embed(X) + else: + x_cat = None + + if self.continuous_cols is not None: + x_cont = self.cont_norm((X[:, self.cont_idx].float())) + if self.embed_continuous: + x_cont = self.cont_embed(x_cont) + else: + x_cont = None + + return x_cat, x_cont diff --git a/pytorch_widedeep/models/transformers/_encoders.py b/pytorch_widedeep/models/transformers/_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..84643b6323a7cc0cb887bdd415f026feed779e37 --- /dev/null +++ b/pytorch_widedeep/models/transformers/_encoders.py @@ -0,0 +1,193 @@ +import einops +from torch import nn + +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.models.transformers._attention_layers import ( + AddNorm, + NormAdd, + PositionwiseFF, + LinearAttention, + AdditiveAttention, + MultiHeadedAttention, +) + + +class TransformerEncoder(nn.Module): + def __init__( + self, + input_dim: int, + n_heads: int, + use_bias: bool, + attn_dropout: float, + ff_dropout: float, + activation: str, + ): + super(TransformerEncoder, self).__init__() + + self.attn = MultiHeadedAttention( + input_dim, + n_heads, + use_bias, + attn_dropout, + ) + self.ff = PositionwiseFF(input_dim, ff_dropout, activation) + + self.attn_addnorm = AddNorm(input_dim, attn_dropout) + self.ff_addnorm = AddNorm(input_dim, ff_dropout) + + def forward(self, X: Tensor) -> Tensor: + x = self.attn_addnorm(X, self.attn) + return self.ff_addnorm(x, self.ff) + + +class SaintEncoder(nn.Module): + def __init__( + self, + input_dim: int, + n_heads: int, + use_bias: bool, + attn_dropout: float, + ff_dropout: float, + activation: str, + n_feat: int, + ): + super(SaintEncoder, self).__init__() + + self.n_feat = n_feat + + self.col_attn = MultiHeadedAttention( + input_dim, + n_heads, + use_bias, + attn_dropout, + ) + self.col_attn_ff = PositionwiseFF(input_dim, ff_dropout, activation) + self.col_attn_addnorm = AddNorm(input_dim, attn_dropout) + self.col_attn_ff_addnorm = AddNorm(input_dim, ff_dropout) + + self.row_attn = MultiHeadedAttention( + n_feat * input_dim, + n_heads, + use_bias, + attn_dropout, + ) + self.row_attn_ff = PositionwiseFF(n_feat * input_dim, ff_dropout, activation) + self.row_attn_addnorm = AddNorm(n_feat * input_dim, attn_dropout) + self.row_attn_ff_addnorm = AddNorm(n_feat * input_dim, ff_dropout) + + def forward(self, X: Tensor) -> Tensor: + x = self.col_attn_addnorm(X, self.col_attn) + x = self.col_attn_ff_addnorm(x, self.col_attn_ff) + x = einops.rearrange(x, "b n d -> 1 b (n d)") + x = self.row_attn_addnorm(x, self.row_attn) + x = self.row_attn_ff_addnorm(x, self.row_attn_ff) + x = einops.rearrange(x, "1 b (n d) -> b n d", n=self.n_feat) + return x + + +class FTTransformerEncoder(nn.Module): + def __init__( + self, + input_dim: int, + n_feats: int, + n_heads: int, + use_bias: bool, + attn_dropout: float, + ff_dropout: float, + kv_compression_factor: float, + kv_sharing: bool, + activation: str, + ff_factor: float, + first_block: bool, + ): + super(FTTransformerEncoder, self).__init__() + + self.first_block = first_block + + self.attn = LinearAttention( + input_dim, + n_feats, + n_heads, + use_bias, + attn_dropout, + kv_compression_factor, + kv_sharing, + ) + self.ff = PositionwiseFF(input_dim, ff_dropout, activation, ff_factor) + + self.attn_normadd = NormAdd(input_dim, attn_dropout) + self.ff_normadd = NormAdd(input_dim, ff_dropout) + + def forward(self, X: Tensor) -> Tensor: + if self.first_block: + x = X + self.attn(X) + else: + x = self.attn_normadd(X, self.attn) + return self.ff_normadd(x, self.ff) + + +class PerceiverEncoder(nn.Module): + def __init__( + self, + input_dim: int, + n_heads: int, + use_bias: bool, + attn_dropout: float, + ff_dropout: float, + activation: str, + query_dim: Optional[int] = None, + ): + super(PerceiverEncoder, self).__init__() + + self.attn = MultiHeadedAttention( + input_dim, + n_heads, + use_bias, + attn_dropout, + query_dim, + ) + attn_dim_out = query_dim if query_dim is not None else input_dim + self.ff = PositionwiseFF(attn_dim_out, ff_dropout, activation) + + self.ln_q = nn.LayerNorm(attn_dim_out) + self.ln_kv = nn.LayerNorm(input_dim) + self.norm_attn_dropout = nn.Dropout(attn_dropout) + + self.ff_norm = nn.LayerNorm(attn_dim_out) + self.norm_ff_dropout = nn.Dropout(ff_dropout) + + def forward(self, X_Q: Tensor, X_KV: Optional[Tensor] = None) -> Tensor: + x = self.ln_q(X_Q) + y = None if X_KV is None else self.ln_kv(X_KV) + x = x + self.norm_attn_dropout(self.attn(x, y)) + return x + self.norm_ff_dropout(self.ff(self.ff_norm(x))) + + +class FastFormerEncoder(nn.Module): + def __init__( + self, + input_dim: int, + n_heads: int, + use_bias: bool, + attn_dropout: float, + ff_dropout: float, + share_qv_weights: bool, + activation: str, + ): + super(FastFormerEncoder, self).__init__() + + self.attn = AdditiveAttention( + input_dim, + n_heads, + use_bias, + attn_dropout, + share_qv_weights, + ) + + self.ff = PositionwiseFF(input_dim, ff_dropout, activation) + self.attn_addnorm = AddNorm(input_dim, attn_dropout) + self.ff_addnorm = AddNorm(input_dim, ff_dropout) + + def forward(self, X: Tensor) -> Tensor: + x = self.attn_addnorm(X, self.attn) + return self.ff_addnorm(x, self.ff) diff --git a/pytorch_widedeep/models/transformers/ft_transformer.py b/pytorch_widedeep/models/transformers/ft_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b7361b1520b866753eba46b0b66ba79abe72b20 --- /dev/null +++ b/pytorch_widedeep/models/transformers/ft_transformer.py @@ -0,0 +1,300 @@ +from torch import nn + +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.models.tab_mlp import MLP +from pytorch_widedeep.models.transformers._encoders import FTTransformerEncoder +from pytorch_widedeep.models.transformers._embeddings_layers import ( + CatAndContEmbeddings, +) + + +class FTTransformer(nn.Module): + r"""Defines a ``FTTransformer`` model + (`arXiv:2106.11959 `_) that can be + used as the ``deeptabular`` component of a Wide & Deep model. + + Parameters + ---------- + column_idx: Dict + Dict containing the index of the columns that will be passed through + the model. Required to slice the tensors. e.g. + {'education': 0, 'relationship': 1, 'workclass': 2, ...} + embed_input: List + List of Tuples with the column name and number of unique values + e.g. [('education', 11), ...] + embed_dropout: float, default = 0.1 + Dropout to be applied to the embeddings matrix + full_embed_dropout: bool, default = False + Boolean indicating if an entire embedding (i.e. the representation of + one column) will be dropped in the batch. See: + :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`. + If ``full_embed_dropout = True``, ``embed_dropout`` is ignored. + shared_embed: bool, default = False + The idea behind ``shared_embed`` is described in the Appendix A in the + `TabTransformer paper `_: `'The + goal of having column embedding is to enable the model to distinguish + the classes in one column from those in the other columns'`. In other + words, the idea is to let the model learn which column is embedded + at the time. + add_shared_embed: bool, default = False, + The two embedding sharing strategies are: 1) add the shared embeddings to the column + embeddings or 2) to replace the first ``frac_shared_embed`` with the shared + embeddings. See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings` + frac_shared_embed: float, default = 0.25 + The fraction of embeddings that will be shared (if ``add_shared_embed + = False``) by all the different categories for one particular + column. + continuous_cols: List, Optional, default = None + List with the name of the numeric (aka continuous) columns + embed_continuous_activation: str, default = None + String indicating the activation function to be applied to the + continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported. + cont_norm_layer: str, default = None, + Type of normalization layer applied to the continuous features before + they are embedded. Options are: ``layernorm``, ``batchnorm`` or + ``None``. + input_dim: int, default = 64 + The so-called *dimension of the model*. Is the number of embeddings used to encode + the categorical and/or continuous columns. + kv_compression_factor: int, default = 0.5 + By default, the FTTransformer uses Linear Attention + (See `Linformer: Self-Attention with Linear Complexity + `_ ) The compression factor that + will be used to reduce the input sequence length. If we denote the + resulting sequence length as :math:`k` + :math:`k = int(kv_{compression \space factor} \times s)` + where :math:`s` is the input sequence length. + kv_sharing: bool, default = False + Boolean indicating if the :math:`E` and :math:`F` projection matrices + will share weights. See `Linformer: Self-Attention with Linear + Complexity `_ for details + n_heads: int, default = 8 + Number of attention heads per FTTransformer block + use_bias: bool, default = False + Boolean indicating whether or not to use bias in the Q, K, and V + projection layers + n_blocks: int, default = 4 + Number of FTTransformer blocks + attn_dropout: float, default = 0.2 + Dropout that will be applied to the Linear-Attention layers + ff_dropout: float, default = 0.1 + Dropout that will be applied to the FeedForward network + transformer_activation: str, default = "gelu" + Transformer Encoder activation function. ``tanh``, ``relu``, + ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported + ff_factor: float, default = 4 / 3 + Multiplicative factor applied to the first layer of the FF network in + each Transformer block, This is normally set to 4, but they use 4/3 + in the paper. + mlp_hidden_dims: List, Optional, default = None + MLP hidden dimensions. If not provided no MLP on top of the final + FTTransformer block will be used + mlp_activation: str, default = "relu" + MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported + mlp_dropout: float, default = 0.1 + Dropout that will be applied to the final MLP + mlp_batchnorm: bool, default = False + Boolean indicating whether or not to apply batch normalization to the + dense layers + mlp_batchnorm_last: bool, default = False + Boolean indicating whether or not to apply batch normalization to the + last of the dense layers + mlp_linear_first: bool, default = False + Boolean indicating whether the order of the operations in the dense + layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP -> + LIN -> ACT]`` + + Attributes + ---------- + cat_and_cont_embed: ``nn.Module`` + This is the module that processes the categorical and continuous columns + transformer_blks: ``nn.Sequential`` + Sequence of FTTransformer blocks + transformer_mlp: ``nn.Module`` + MLP component in the model + output_dim: int + The output dimension of the model. This is a required attribute + neccesary to build the WideDeep class + + Example + -------- + >>> import torch + >>> from pytorch_widedeep.models import FTTransformer + >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1) + >>> colnames = ['a', 'b', 'c', 'd', 'e'] + >>> embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)] + >>> continuous_cols = ['e'] + >>> column_idx = {k:v for v,k in enumerate(colnames)} + >>> model = FTTransformer(column_idx=column_idx, embed_input=embed_input, continuous_cols=continuous_cols) + >>> out = model(X_tab) + """ + + def __init__( + self, + column_idx: Dict[str, int], + embed_input: List[Tuple[str, int]], + embed_dropout: float = 0.1, + full_embed_dropout: bool = False, + shared_embed: bool = False, + add_shared_embed: bool = False, + frac_shared_embed: float = 0.25, + continuous_cols: Optional[List[str]] = None, + embed_continuous_activation: str = None, + cont_norm_layer: str = None, + input_dim: int = 64, + kv_compression_factor: float = 0.5, + kv_sharing: bool = False, + use_bias: bool = False, + n_heads: int = 8, + n_blocks: int = 4, + attn_dropout: float = 0.2, + ff_dropout: float = 0.1, + transformer_activation: str = "reglu", + ff_factor: float = 1.33, + mlp_hidden_dims: Optional[List[int]] = None, + mlp_activation: str = "relu", + mlp_dropout: float = 0.1, + mlp_batchnorm: bool = False, + mlp_batchnorm_last: bool = False, + mlp_linear_first: bool = True, + ): + super(FTTransformer, self).__init__() + + self.column_idx = column_idx + self.embed_input = embed_input + self.embed_dropout = embed_dropout + self.full_embed_dropout = full_embed_dropout + self.shared_embed = shared_embed + self.add_shared_embed = add_shared_embed + self.frac_shared_embed = frac_shared_embed + self.continuous_cols = continuous_cols + self.embed_continuous_activation = embed_continuous_activation + self.cont_norm_layer = cont_norm_layer + self.input_dim = input_dim + self.kv_compression_factor = kv_compression_factor + self.kv_sharing = kv_sharing + self.use_bias = use_bias + self.n_heads = n_heads + self.n_blocks = n_blocks + self.attn_dropout = attn_dropout + self.ff_dropout = ff_dropout + self.transformer_activation = transformer_activation + self.ff_factor = ff_factor + self.mlp_hidden_dims = mlp_hidden_dims + self.mlp_activation = mlp_activation + self.mlp_dropout = mlp_dropout + self.mlp_batchnorm = mlp_batchnorm + self.mlp_batchnorm_last = mlp_batchnorm_last + self.mlp_linear_first = mlp_linear_first + + self.with_cls_token = "cls_token" in column_idx + self.n_cat = len(embed_input) if embed_input is not None else 0 + self.n_cont = len(continuous_cols) if continuous_cols is not None else 0 + self.n_feats = self.n_cat + self.n_cont + + if self.n_cont and not self.n_cat and not self.embed_continuous: + raise ValueError( + "If only continuous features are used 'embed_continuous' must be set to 'True'" + ) + + self.cat_and_cont_embed = CatAndContEmbeddings( + input_dim, + column_idx, + embed_input, + embed_dropout, + full_embed_dropout, + shared_embed, + add_shared_embed, + frac_shared_embed, + True, # use_embed_bias + continuous_cols, + True, # embed_continuous, + embed_continuous_activation, + True, # use_cont_bias + cont_norm_layer, + ) + + is_first = True + self.transformer_blks = nn.Sequential() + for i in range(n_blocks): + self.transformer_blks.add_module( + "fttransformer_block" + str(i), + FTTransformerEncoder( + input_dim, + self.n_feats, + n_heads, + use_bias, + attn_dropout, + ff_dropout, + kv_compression_factor, + kv_sharing, + transformer_activation, + ff_factor, + is_first, + ), + ) + is_first = False + + if mlp_hidden_dims is not None: + attn_output_dim = ( + self.input_dim + if self.with_cls_token + else (self.n_cat + self.n_cont) * self.input_dim + ) + assert mlp_hidden_dims[0] == attn_output_dim, ( + f"The input dim of the MLP must be {attn_output_dim}. " + f"Got {mlp_hidden_dims[0]} instead" + ) + self.transformer_mlp = MLP( + mlp_hidden_dims, + mlp_activation, + mlp_dropout, + mlp_batchnorm, + mlp_batchnorm_last, + mlp_linear_first, + ) + # the output_dim attribute will be used as input_dim when "merging" the models + self.output_dim = mlp_hidden_dims[-1] + else: + self.transformer_mlp = None + self.output_dim = ( + input_dim if self.with_cls_token else (self.n_feats * input_dim) + ) + + def forward(self, X: Tensor) -> Tensor: + + x_cat, x_cont = self.cat_and_cont_embed(X) + + if x_cat is not None: + x = x_cat + if x_cont is not None: + x = torch.cat([x, x_cont], 1) if x_cat is not None else x_cont + + x = self.transformer_blks(x) + + if self.with_cls_token: + x = x[:, 0, :] + else: + x = x.flatten(1) + + if self.transformer_mlp is not None: + x = self.transformer_mlp(x) + + return x + + @property + def attention_weights(self) -> List: + r"""List with the attention weights + + The shape of the attention weights is: + + :math:`(N, H, F, k)` + + where *N* is the batch size, *H* is the number of attention heads, *F* + is the number of features/columns and *k* is the reduced sequence + length or dimension, i.e. :math:`k = int(kv_ + {compression \space factor} \times s)` + """ + return [blk.attn.attn_weights for blk in self.transformer_blks] diff --git a/pytorch_widedeep/models/transformers/layers.py b/pytorch_widedeep/models/transformers/layers.py deleted file mode 100644 index 0c6ef5430c24acb3bab13512e004989b8b478194..0000000000000000000000000000000000000000 --- a/pytorch_widedeep/models/transformers/layers.py +++ /dev/null @@ -1,257 +0,0 @@ -""" -The code in this module is inspired by a number of implementations: - -Classes PositionwiseFF and AddNorm are 'stolen' with much gratitude from the fantastic d2l.ai book: -https://d2l.ai/chapter_attention-mechanisms/transformer.html - -MultiHeadedAttention is inspired by the TabTransformer implementation here: -https://github.com/lucidrains/tab-transformer-pytorch. General comment: just go and have a look to -https://github.com/lucidrains - -SharedEmbeddings is inspired by the TabTransformer available in AutoGluon: -https://github.com/awslabs/autogluon/tree/master/tabular/src/autogluon/tabular/models/tab_transformer -If you have not checked that library, you should. - -The forward pass of the SaintEncoder is based on the original code release: -https://github.com/somepago/saint -""" - -import math - -import torch -import einops -from torch import nn, einsum - -from pytorch_widedeep.wdtypes import * # noqa: F403 -from pytorch_widedeep.models.tab_mlp import _get_activation_fn - - -class PositionwiseFF(nn.Module): - def __init__( - self, - input_dim: int, - ff_hidden_dim: int, - dropout: float, - activation: str, - ): - super(PositionwiseFF, self).__init__() - self.w_1 = nn.Linear( - input_dim, ff_hidden_dim * 2 if activation == "geglu" else ff_hidden_dim - ) - self.w_2 = nn.Linear(ff_hidden_dim, input_dim) - self.dropout = nn.Dropout(dropout) - self.activation = _get_activation_fn(activation) - - def forward(self, X: Tensor) -> Tensor: - return self.w_2(self.dropout(self.activation(self.w_1(X)))) - - -class AddNorm(nn.Module): - def __init__(self, input_dim: int, dropout: float): - super(AddNorm, self).__init__() - self.dropout = nn.Dropout(dropout) - self.ln = nn.LayerNorm(input_dim) - - def forward(self, X: Tensor, Y: Tensor) -> Tensor: - return self.ln(self.dropout(Y) + X) - - -class MultiHeadedAttention(nn.Module): - def __init__( - self, - input_dim: int, - n_heads: int, - keep_attn_weights: bool, - dropout: float, - ): - super(MultiHeadedAttention, self).__init__() - - assert input_dim % n_heads == 0, "'input_dim' must be divisible by 'n_heads'" - # Consistent with other implementations I assume d_v = d_k - self.d_k = input_dim // n_heads - self.n_heads = n_heads - self.dropout = nn.Dropout(dropout) - self.inp_proj = nn.Linear(input_dim, input_dim * 3) - self.out_proj = nn.Linear(input_dim, input_dim) - self.keep_attn_weights = keep_attn_weights - - def forward(self, X: Tensor) -> Tensor: - # b: batch size, s: src seq length (num of categorical features - # encoded as embeddings), l: target sequence (l = s), e: embeddings - # dimensions, h: number of attention heads, d: d_k - q, k, v = self.inp_proj(X).chunk(3, dim=2) - q, k, v = map( - lambda t: einops.rearrange(t, "b s (h d) -> b h s d", h=self.n_heads), - (q, k, v), - ) - scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(self.d_k) - attn_weights = self.dropout(scores.softmax(dim=-1)) - if self.keep_attn_weights: - self.attn_weights = attn_weights - attn_output = einsum("b h s l, b h l d -> b h s d", attn_weights, v) - output = einops.rearrange(attn_output, "b h s d -> b s (h d)", h=self.n_heads) - - return self.out_proj(output) - - -class TransformerEncoder(nn.Module): - def __init__( - self, - input_dim: int, - n_heads: int, - keep_attn_weights: bool, - ff_hidden_dim: int, - dropout: float, - activation: str, - ): - super(TransformerEncoder, self).__init__() - - self.self_attn = MultiHeadedAttention( - input_dim, - n_heads, - keep_attn_weights, - dropout, - ) - self.ff = PositionwiseFF(input_dim, ff_hidden_dim, dropout, activation) - self.attn_addnorm = AddNorm(input_dim, dropout) - self.ff_addnorm = AddNorm(input_dim, dropout) - - def forward(self, X: Tensor) -> Tensor: - x = self.attn_addnorm(X, self.self_attn(X)) - return self.ff_addnorm(x, self.ff(x)) - - -class SaintEncoder(nn.Module): - def __init__( - self, - input_dim: int, - n_heads: int, - keep_attn_weights: bool, - ff_hidden_dim: int, - dropout: float, - activation: str, - n_feat: int, - ): - super(SaintEncoder, self).__init__() - - self.n_feat = n_feat - - self.self_attn = MultiHeadedAttention( - input_dim, - n_heads, - keep_attn_weights, - dropout, - ) - self.self_attn_ff = PositionwiseFF( - input_dim, ff_hidden_dim, dropout, activation - ) - self.self_attn_addnorm = AddNorm(input_dim, dropout) - self.self_attn_ff_addnorm = AddNorm(input_dim, dropout) - - self.row_attn = MultiHeadedAttention( - n_feat * input_dim, - n_heads, - keep_attn_weights, - dropout, - ) - self.row_attn_ff = PositionwiseFF( - n_feat * input_dim, n_feat * ff_hidden_dim, dropout, activation - ) - self.row_attn_addnorm = AddNorm(n_feat * input_dim, dropout) - self.row_attn_ff_addnorm = AddNorm(n_feat * input_dim, dropout) - - def forward(self, X: Tensor) -> Tensor: - x = self.self_attn_addnorm(X, self.self_attn(X)) - x = self.self_attn_ff_addnorm(x, self.self_attn_ff(x)) - x = einops.rearrange(x, "b n d -> 1 b (n d)") - x = self.row_attn_addnorm(x, self.row_attn(x)) - x = self.row_attn_ff_addnorm(x, self.row_attn_ff(x)) - x = einops.rearrange(x, "1 b (n d) -> b n d", n=self.n_feat) - return x - - -class FullEmbeddingDropout(nn.Module): - def __init__(self, dropout: float): - super(FullEmbeddingDropout, self).__init__() - self.dropout = dropout - - def forward(self, X: Tensor) -> Tensor: - mask = X.new().resize_((X.size(1), 1)).bernoulli_(1 - self.dropout).expand_as( - X - ) / (1 - self.dropout) - return mask * X - - -class SharedEmbeddings(nn.Module): - def __init__( - self, - n_embed: int, - embed_dim: int, - embed_dropout: float, - full_embed_dropout: bool = False, - add_shared_embed: bool = False, - frac_shared_embed=0.25, - ): - super(SharedEmbeddings, self).__init__() - - assert frac_shared_embed < 1, "'frac_shared_embed' must be less than 1" - self.add_shared_embed = add_shared_embed - self.embed = nn.Embedding(n_embed, embed_dim, padding_idx=0) - self.embed.weight.data.clamp_(-2, 2) - if add_shared_embed: - col_embed_dim = embed_dim - else: - col_embed_dim = int(embed_dim * frac_shared_embed) - self.shared_embed = nn.Parameter(torch.empty(1, col_embed_dim).uniform_(-1, 1)) - - if full_embed_dropout: - self.dropout: DropoutLayers = FullEmbeddingDropout(embed_dropout) - else: - self.dropout = nn.Dropout(embed_dropout) - - def forward(self, X: Tensor) -> Tensor: - out = self.dropout(self.embed(X)) - shared_embed = self.shared_embed.expand(out.shape[0], -1) - if self.add_shared_embed: - out += shared_embed - else: - out[:, : shared_embed.shape[1]] = shared_embed - return out - - -class ContinuousEmbeddings(nn.Module): - def __init__( - self, - n_cont_cols: int, - embed_dim: int, - activation: str = None, - bias: bool = True, - ): - super(ContinuousEmbeddings, self).__init__() - self.n_cont_cols = n_cont_cols - self.embed_dim = embed_dim - self.activation = activation - - self.weight = nn.Parameter(torch.Tensor(n_cont_cols, embed_dim)) - self.bias = nn.Parameter(torch.Tensor(n_cont_cols, embed_dim)) if bias else None - self._reset_parameters() - - self.act_fn = _get_activation_fn(activation) if activation else None - - def _reset_parameters(self) -> None: - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(self.bias, -bound, bound) - - def forward(self, X: Tensor) -> Tensor: - x = self.weight.unsqueeze(0) * X.unsqueeze(2) - - if self.bias is not None: - x = x + self.bias.unsqueeze(0) - - if self.act_fn is not None: - x = self.act_fn(x) - - return x diff --git a/pytorch_widedeep/models/transformers/saint.py b/pytorch_widedeep/models/transformers/saint.py index 403cdbdfa7712d24099ec4b33112ce42d3bb194c..40c52927f8df2eb93e8829eb44c51a752f8e48e6 100644 --- a/pytorch_widedeep/models/transformers/saint.py +++ b/pytorch_widedeep/models/transformers/saint.py @@ -1,82 +1,86 @@ from torch import nn from pytorch_widedeep.wdtypes import * # noqa: F403 -from pytorch_widedeep.models.transformers.layers import SaintEncoder -from pytorch_widedeep.models.transformers.tab_transformer import TabTransformer +from pytorch_widedeep.models.tab_mlp import MLP +from pytorch_widedeep.models.transformers._encoders import SaintEncoder +from pytorch_widedeep.models.transformers._embeddings_layers import ( + CatAndContEmbeddings, +) -class SAINT(TabTransformer): - r"""Adaptation of SAINT model - (https://arxiv.org/abs/2106.01342) model that can be used as the - deeptabular component of a Wide & Deep model. - - Parameters for this model are identical to those of the ``TabTransformer`` +class SAINT(nn.Module): + r"""Defines a ``SAINT`` model + (`arXiv:2106.01342 `_) that can be used + as the ``deeptabular`` component of a Wide & Deep model. Parameters ---------- column_idx: Dict Dict containing the index of the columns that will be passed through - the DeepDense model. Required to slice the tensors. e.g. {'education': - 0, 'relationship': 1, 'workclass': 2, ...} + the model. Required to slice the tensors. e.g. + {'education': 0, 'relationship': 1, 'workclass': 2, ...} embed_input: List List of Tuples with the column name and number of unique values - e.g. [(education, 11), ...] + e.g. [('education', 11), ...] embed_dropout: float, default = 0.1 Dropout to be applied to the embeddings matrix full_embed_dropout: bool, default = False Boolean indicating if an entire embedding (i.e. the representation of one column) will be dropped in the batch. See: - :obj:`pytorch_widedeep.models.transformers.layers.FullEmbeddingDropout`. + :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`. If ``full_embed_dropout = True``, ``embed_dropout`` is ignored. shared_embed: bool, default = False - The idea behind ``shared_embed`` is described in the Appendix A in the paper: - `'The goal of having column embedding is to enable the model to distinguish the - classes in one column from those in the other columns'`. In other words, the idea - is to let the model learn which column is embedding at the time. - add_shared_embed: bool, default = False, - The two embedding sharing strategies are: 1) add the shared embeddings to the column - embeddings or 2) to replace the first ``frac_shared_embed`` with the shared - embeddings. See :obj:`pytorch_widedeep.models.transformers.layers.SharedEmbeddings` + The idea behind ``shared_embed`` is described in the Appendix A in the + `TabTransformer paper `_: `'The + goal of having column embedding is to enable the model to distinguish + the classes in one column from those in the other columns'`. In other + words, the idea is to let the model learn which column is embedded + at the time. + add_shared_embed: bool, default = False + The two embedding sharing strategies are: 1) add the shared embeddings + to the column embeddings or 2) to replace the first + ``frac_shared_embed`` with the shared embeddings. + See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings` frac_shared_embed: float, default = 0.25 - The fraction of embeddings that will be shared by all the different categories for - one particular column. + The fraction of embeddings that will be shared (if ``add_shared_embed + = False``) by all the different categories for one particular + column. continuous_cols: List, Optional, default = None List with the name of the numeric (aka continuous) columns - embed_continuous: bool, default = False, - Boolean indicating if the continuous features will be "embedded". See - ``pytorch_widedeep.models.transformers.layers.ContinuousEmbeddings`` - embed_continuous_activation: str, default = "relu" + embed_continuous_activation: str, default = None String indicating the activation function to be applied to the - continuous embeddings, if any. - 'relu', 'leaky_relu' and 'gelu' are supported. - cont_norm_layer: str, default = "layernorm", - Type of normalization layer applied to the continuous features if they - are not embedded. Options are: 'layernorm' or 'batchnorm'. + continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported. + cont_norm_layer: str, default = None, + Type of normalization layer applied to the continuous features before + they are embedded. Options are: ``layernorm``, ``batchnorm`` or + ``None``. input_dim: int, default = 32 - The so-called *dimension of the model*. Is the number of embeddings used to encode - the categorical columns + The so-called *dimension of the model*. In general is the number of + embeddings used to encode the categorical and/or continuous columns n_heads: int, default = 8 Number of attention heads per Transformer block - n_blocks: int, default = 6 - Number of Transformer blocks - dropout: float, default = 0.1 - Dropout that will be applied internally to the ``TransformerEncoder`` - (see :obj:`pytorch_widedeep.models.transformers.layers.TransformerEncoder`) - and the output MLP - keep_attn_weights: bool, default = False - If set to ``True`` the model will store the attention weights in the ``attention_weights`` - attribute. - ff_hidden_dim: int, default = 128 - Hidden dimension of the ``FeedForward`` Layer. See - :obj:`pytorch_widedeep.models.transformers.layers.FeedForward`. + use_bias: bool, default = False + Boolean indicating whether or not to use bias in the Q, K, and V + projection layers + n_blocks: int, default = 2 + Number of SAINT-Transformer blocks. 1 in the paper. + attn_dropout: float, default = 0.2 + Dropout that will be applied to the Multi-Head Attention column and + row layers + ff_dropout: float, default = 0.1 + Dropout that will be applied to the FeedForward network transformer_activation: str, default = "gelu" - Transformer Encoder activation function. 'relu', 'leaky_relu', 'gelu' - and 'geglu' are supported + Transformer Encoder activation function. ``tanh``, ``relu``, + ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported mlp_hidden_dims: List, Optional, default = None - MLP hidden dimensions. If not provided it will default to ``[4*l, - 2*l]`` where ``l`` is the mlp input dimension + MLP hidden dimensions. If not provided it will default to ``[l, 4*l, + 2*l]`` where ``l`` is the MLP input dimension mlp_activation: str, default = "relu" - MLP activation function. 'relu', 'leaky_relu' and 'gelu' are supported + MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported + mlp_dropout: float, default = 0.1 + Dropout that will be applied to the final MLP mlp_batchnorm: bool, default = False Boolean indicating whether or not to apply batch normalization to the dense layers @@ -90,20 +94,10 @@ class SAINT(TabTransformer): Attributes ---------- - cat_embed_layers: ``nn.ModuleDict`` - Dict with the embeddings per column - cont_embed: ``nn.ModuleDict`` - Dict with the embeddings per column if ``embed_continuous=True`` - cont_norm_layer: NormLayers - Continuous normalization layer if ``continuous_cols`` is not None + cat_and_cont_embed: ``nn.Module`` + This is the module that processes the categorical and continuous columns transformer_blks: ``nn.Sequential`` - Sequence of Transformer blocks - attention_weights: List - List of tuples with the attention weights per block if - ``keep_attn_weights = True``. The first element in each tuples is the - attention weights corresponding to the self attention mechanism - (i.e. column attention) and the second element is the inter-sample - attention weights (i.e. row attention) + Sequence of SAINT-Transformer blocks transformer_mlp: ``nn.Module`` MLP component in the model output_dim: int @@ -133,23 +127,59 @@ class SAINT(TabTransformer): add_shared_embed: bool = False, frac_shared_embed: float = 0.25, continuous_cols: Optional[List[str]] = None, - embed_continuous: bool = False, embed_continuous_activation: str = None, - cont_norm_layer: str = "layernorm", + cont_norm_layer: str = None, input_dim: int = 32, + use_bias: bool = False, n_heads: int = 8, - n_blocks: int = 6, - dropout: float = 0.1, - keep_attn_weights: bool = False, - ff_hidden_dim: int = 32 * 4, + n_blocks: int = 2, + attn_dropout: float = 0.1, + ff_dropout: float = 0.2, transformer_activation: str = "gelu", mlp_hidden_dims: Optional[List[int]] = None, mlp_activation: str = "relu", + mlp_dropout: float = 0.1, mlp_batchnorm: bool = False, mlp_batchnorm_last: bool = False, mlp_linear_first: bool = True, ): - super().__init__( + super(SAINT, self).__init__() + + self.column_idx = column_idx + self.embed_input = embed_input + self.embed_dropout = embed_dropout + self.full_embed_dropout = full_embed_dropout + self.shared_embed = shared_embed + self.add_shared_embed = add_shared_embed + self.frac_shared_embed = frac_shared_embed + self.continuous_cols = continuous_cols + self.embed_continuous_activation = embed_continuous_activation + self.cont_norm_layer = cont_norm_layer + self.input_dim = input_dim + self.use_bias = use_bias + self.n_heads = n_heads + self.n_blocks = n_blocks + self.attn_dropout = attn_dropout + self.ff_dropout = ff_dropout + self.transformer_activation = transformer_activation + self.mlp_hidden_dims = mlp_hidden_dims + self.mlp_activation = mlp_activation + self.mlp_batchnorm = mlp_batchnorm + self.mlp_batchnorm_last = mlp_batchnorm_last + self.mlp_linear_first = mlp_linear_first + + self.with_cls_token = "cls_token" in column_idx + self.n_cat = len(embed_input) if embed_input is not None else 0 + self.n_cont = len(continuous_cols) if continuous_cols is not None else 0 + self.n_feats = self.n_cat + self.n_cont + + if self.n_cont and not self.n_cat and not self.embed_continuous: + raise ValueError( + "If only continuous features are used 'embed_continuous' must be set to 'True'" + ) + + self.cat_and_cont_embed = CatAndContEmbeddings( + input_dim, column_idx, embed_input, embed_dropout, @@ -157,39 +187,91 @@ class SAINT(TabTransformer): shared_embed, add_shared_embed, frac_shared_embed, + False, # use_embed_bias continuous_cols, - embed_continuous, + True, # embed_continuous, embed_continuous_activation, + True, # use_cont_bias cont_norm_layer, - input_dim, - n_heads, - n_blocks, - dropout, - keep_attn_weights, - ff_hidden_dim, - transformer_activation, - mlp_hidden_dims, - mlp_activation, - mlp_batchnorm, - mlp_batchnorm_last, - mlp_linear_first, ) - if embed_continuous: - n_feats = len(embed_input) + len(continuous_cols) - else: - n_feats = len(embed_input) self.transformer_blks = nn.Sequential() for i in range(n_blocks): self.transformer_blks.add_module( - "block" + str(i), + "saint_block" + str(i), SaintEncoder( input_dim, n_heads, - keep_attn_weights, - ff_hidden_dim, - dropout, + use_bias, + attn_dropout, + ff_dropout, transformer_activation, - n_feats, + self.n_feats, ), ) + + attn_output_dim = ( + self.input_dim if self.with_cls_token else self.n_feats * self.input_dim + ) + if not mlp_hidden_dims: + mlp_hidden_dims = [ + attn_output_dim, + attn_output_dim * 4, + attn_output_dim * 2, + ] + else: + assert mlp_hidden_dims[0] == attn_output_dim, ( + f"The input dim of the MLP must be {attn_output_dim}. " + f"Got {mlp_hidden_dims[0]} instead" + ) + self.transformer_mlp = MLP( + mlp_hidden_dims, + mlp_activation, + mlp_dropout, + mlp_batchnorm, + mlp_batchnorm_last, + mlp_linear_first, + ) + + # the output_dim attribute will be used as input_dim when "merging" the models + self.output_dim = mlp_hidden_dims[-1] + + def forward(self, X: Tensor) -> Tensor: + + x_cat, x_cont = self.cat_and_cont_embed(X) + + if x_cat is not None: + x = x_cat + if x_cont is not None: + x = torch.cat([x, x_cont], 1) if x_cat is not None else x_cont + + x = self.transformer_blks(x) + + if self.with_cls_token: + x = x[:, 0, :] + else: + x = x.flatten(1) + + return self.transformer_mlp(x) + + @property + def attention_weights(self) -> List: + r"""List with the attention weights. Each element of the list is a tuple + where the first and the second elements are the column and row + attention weights respectively + + The shape of the attention weights is: + + - column attention: :math:`(N, H, F, F)` + + - row attention: :math:`(1, H, N, N)` + + where *N* is the batch size, *H* is the number of heads and *F* is the + number of features/columns in the dataset + """ + attention_weights = [] + for blk in self.transformer_blks: + attention_weights.append( + (blk.col_attn.attn_weights, blk.row_attn.attn_weights) + ) + return attention_weights diff --git a/pytorch_widedeep/models/transformers/tab_fastformer.py b/pytorch_widedeep/models/transformers/tab_fastformer.py new file mode 100644 index 0000000000000000000000000000000000000000..75b0816ea8e911fd1fe56374d3376b0b969b09d1 --- /dev/null +++ b/pytorch_widedeep/models/transformers/tab_fastformer.py @@ -0,0 +1,300 @@ +from torch import nn + +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.models.tab_mlp import MLP +from pytorch_widedeep.models.transformers._encoders import FastFormerEncoder +from pytorch_widedeep.models.transformers._embeddings_layers import ( + CatAndContEmbeddings, +) + + +class TabFastFormer(nn.Module): + r"""Defines an adaptation of a ``FastFormer`` model + (`arXiv:2108.09084 `_) that can be used + as the ``deeptabular`` component of a Wide & Deep model. + + Parameters + ---------- + column_idx: Dict + Dict containing the index of the columns that will be passed through + the model. Required to slice the tensors. e.g. + {'education': 0, 'relationship': 1, 'workclass': 2, ...} + embed_input: List + List of Tuples with the column name and number of unique values + e.g. [('education', 11), ...] + embed_dropout: float, default = 0.1 + Dropout to be applied to the embeddings matrix + full_embed_dropout: bool, default = False + Boolean indicating if an entire embedding (i.e. the representation of + one column) will be dropped in the batch. See: + :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`. + If ``full_embed_dropout = True``, ``embed_dropout`` is ignored. + shared_embed: bool, default = False + The idea behind ``shared_embed`` is described in the Appendix A in the + `TabTransformer paper `_: `'The + goal of having column embedding is to enable the model to distinguish + the classes in one column from those in the other columns'`. In other + words, the idea is to let the model learn which column is embedded + at the time. + add_shared_embed: bool, default = False, + The two embedding sharing strategies are: 1) add the shared embeddings + to the column embeddings or 2) to replace the first + ``frac_shared_embed`` with the shared embeddings. + See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings` + frac_shared_embed: float, default = 0.25 + The fraction of embeddings that will be shared (if ``add_shared_embed + = False``) by all the different categories for one particular + column. + continuous_cols: List, Optional, default = None + List with the name of the numeric (aka continuous) columns + embed_continuous_activation: str, default = None + String indicating the activation function to be applied to the + continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported. + cont_norm_layer: str, default = None, + Type of normalization layer applied to the continuous features before + they are embedded. Options are: ``layernorm``, ``batchnorm`` or + ``None``. + input_dim: int, default = 32 + The so-called *dimension of the model*. In general is the number of + embeddings used to encode the categorical and/or continuous columns + n_heads: int, default = 8 + Number of attention heads per FastFormer block + use_bias: bool, default = False + Boolean indicating whether or not to use bias in the Q, K, and V + projection layers + n_blocks: int, default = 4 + Number of FastFormer blocks + attn_dropout: float, default = 0.2 + Dropout that will be applied to the Additive Attention layers + ff_dropout: float, default = 0.1 + Dropout that will be applied to the FeedForward network + share_qv_weights: bool, default = False + Following the paper, this is a boolean indicating if the the value and + the query transformation parameters will be shared + share_weights: bool, default = False + In addition to sharing the value and query transformation parameters, + the parameters across different Fastformer layers can also be shared + transformer_activation: str, default = "gelu" + Transformer Encoder activation function. ``tanh``, ``relu``, + ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported + mlp_hidden_dims: List, Optional, default = None + MLP hidden dimensions. If not provided it will default to ``[l, 4*l, + 2*l]`` where ``l`` is the MLP input dimension + mlp_activation: str, default = "relu" + MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported + mlp_dropout: float, default = 0.1 + Dropout that will be applied to the final MLP + mlp_batchnorm: bool, default = False + Boolean indicating whether or not to apply batch normalization to the + dense layers + mlp_batchnorm_last: bool, default = False + Boolean indicating whether or not to apply batch normalization to the + last of the dense layers + mlp_linear_first: bool, default = False + Boolean indicating whether the order of the operations in the dense + layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP -> + LIN -> ACT]`` + + Attributes + ---------- + cat_and_cont_embed: ``nn.Module`` + This is the module that processes the categorical and continuous columns + transformer_blks: ``nn.Sequential`` + Sequence of FasFormer blocks. + transformer_mlp: ``nn.Module`` + MLP component in the model + output_dim: int + The output dimension of the model. This is a required attribute + neccesary to build the WideDeep class + + Example + -------- + >>> import torch + >>> from pytorch_widedeep.models import TabFastFormer + >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1) + >>> colnames = ['a', 'b', 'c', 'd', 'e'] + >>> embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)] + >>> continuous_cols = ['e'] + >>> column_idx = {k:v for v,k in enumerate(colnames)} + >>> model = TabFastFormer(column_idx=column_idx, embed_input=embed_input, continuous_cols=continuous_cols) + >>> out = model(X_tab) + """ + + def __init__( + self, + column_idx: Dict[str, int], + embed_input: Optional[List[Tuple[str, int]]] = None, + embed_dropout: float = 0.1, + full_embed_dropout: bool = False, + shared_embed: bool = False, + add_shared_embed: bool = False, + frac_shared_embed: float = 0.25, + continuous_cols: Optional[List[str]] = None, + embed_continuous_activation: str = None, + cont_norm_layer: str = None, + input_dim: int = 32, + n_heads: int = 8, + use_bias: bool = False, + n_blocks: int = 4, + attn_dropout: float = 0.1, + ff_dropout: float = 0.2, + share_qv_weights: bool = False, + share_weights: bool = False, + transformer_activation: str = "relu", + mlp_hidden_dims: Optional[List[int]] = None, + mlp_activation: str = "relu", + mlp_dropout: float = 0.1, + mlp_batchnorm: bool = False, + mlp_batchnorm_last: bool = False, + mlp_linear_first: bool = True, + ): + super(TabFastFormer, self).__init__() + + self.column_idx = column_idx + self.embed_input = embed_input + self.embed_dropout = embed_dropout + self.full_embed_dropout = full_embed_dropout + self.shared_embed = shared_embed + self.add_shared_embed = add_shared_embed + self.frac_shared_embed = frac_shared_embed + self.continuous_cols = continuous_cols + self.embed_continuous_activation = embed_continuous_activation + self.cont_norm_layer = cont_norm_layer + self.input_dim = input_dim + self.n_heads = n_heads + self.use_bias = use_bias + self.n_blocks = n_blocks + self.attn_dropout = attn_dropout + self.ff_dropout = ff_dropout + self.share_qv_weights = share_qv_weights + self.share_weights = share_weights + self.transformer_activation = transformer_activation + self.mlp_hidden_dims = mlp_hidden_dims + self.mlp_activation = mlp_activation + self.mlp_batchnorm = mlp_batchnorm + self.mlp_batchnorm_last = mlp_batchnorm_last + self.mlp_linear_first = mlp_linear_first + + self.with_cls_token = "cls_token" in column_idx + self.n_cat = len(embed_input) if embed_input is not None else 0 + self.n_cont = len(continuous_cols) if continuous_cols is not None else 0 + self.n_feats = self.n_cat + self.n_cont + + if self.n_cont and not self.n_cat and not self.embed_continuous: + raise ValueError( + "If only continuous features are used 'embed_continuous' must be set to 'True'" + ) + + self.cat_and_cont_embed = CatAndContEmbeddings( + input_dim, + column_idx, + embed_input, + embed_dropout, + full_embed_dropout, + shared_embed, + add_shared_embed, + frac_shared_embed, + False, # use_embed_bias + continuous_cols, + True, # embed_continuous, + embed_continuous_activation, + True, # use_cont_bias + cont_norm_layer, + ) + + self.transformer_blks = nn.Sequential() + first_fastformer_block = FastFormerEncoder( + input_dim, + n_heads, + use_bias, + attn_dropout, + ff_dropout, + share_qv_weights, + transformer_activation, + ) + self.transformer_blks.add_module("fastformer_block0", first_fastformer_block) + for i in range(1, n_blocks): + if share_weights: + self.transformer_blks.add_module( + "fastformer_block" + str(i), first_fastformer_block + ) + else: + self.transformer_blks.add_module( + "fastformer_block" + str(i), + FastFormerEncoder( + input_dim, + n_heads, + use_bias, + attn_dropout, + ff_dropout, + share_qv_weights, + transformer_activation, + ), + ) + + attn_output_dim = ( + self.input_dim + if self.with_cls_token + else (self.n_cat + self.n_cont) * self.input_dim + ) + if not mlp_hidden_dims: + mlp_hidden_dims = [ + attn_output_dim, + attn_output_dim * 4, + attn_output_dim * 2, + ] + else: + assert mlp_hidden_dims[0] == attn_output_dim, ( + f"The input dim of the MLP must be {attn_output_dim}. " + f"Got {mlp_hidden_dims[0]} instead" + ) + self.transformer_mlp = MLP( + mlp_hidden_dims, + mlp_activation, + mlp_dropout, + mlp_batchnorm, + mlp_batchnorm_last, + mlp_linear_first, + ) + + # the output_dim attribute will be used as input_dim when "merging" the models + self.output_dim = mlp_hidden_dims[-1] + + def forward(self, X: Tensor) -> Tensor: + + x_cat, x_cont = self.cat_and_cont_embed(X) + + if x_cat is not None: + x = x_cat + if x_cont is not None: + x = torch.cat([x, x_cont], 1) if x_cat is not None else x_cont + + x = self.transformer_blks(x) + + if self.with_cls_token: + x = x[:, 0, :] + else: + x = x.flatten(1) + + return self.transformer_mlp(x) + + @property + def attention_weights(self) -> List: + r"""List with the attention weights. Each element of the list is a + tuple where the first and second elements are :math:`\alpha` + and :math:`\beta` attention weights in the paper. + + The shape of the attention weights is: + + :math:`(N, H, F)` + + where *N* is the batch size, *H* is the number of attention heads + and *F* is the number of features/columns in the dataset + """ + if self.share_weights: + attention_weights = [self.transformer_blks[0].attn.attn_weight] + else: + attention_weights = [blk.attn.attn_weights for blk in self.transformer_blks] + return attention_weights diff --git a/pytorch_widedeep/models/transformers/tab_perceiver.py b/pytorch_widedeep/models/transformers/tab_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..061a0b82fd6303797920fca800321f0f7e9d9b4b --- /dev/null +++ b/pytorch_widedeep/models/transformers/tab_perceiver.py @@ -0,0 +1,377 @@ +import torch +import einops +from torch import nn + +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.models.tab_mlp import MLP +from pytorch_widedeep.models.transformers._encoders import PerceiverEncoder +from pytorch_widedeep.models.transformers._embeddings_layers import ( + CatAndContEmbeddings, +) + + +class TabPerceiver(nn.Module): + r"""Defines an adaptation of a ``Perceiver`` model + (`arXiv:2103.03206 `_) that can be used + as the ``deeptabular`` component of a Wide & Deep model. + + Parameters + ---------- + column_idx: Dict + Dict containing the index of the columns that will be passed through + the model. Required to slice the tensors. e.g. + {'education': 0, 'relationship': 1, 'workclass': 2, ...} + embed_input: List + List of Tuples with the column name and number of unique values + e.g. [('education', 11), ...] + embed_dropout: float, default = 0.1 + Dropout to be applied to the embeddings matrix + full_embed_dropout: bool, default = False + Boolean indicating if an entire embedding (i.e. the representation of + one column) will be dropped in the batch. See: + :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`. + If ``full_embed_dropout = True``, ``embed_dropout`` is ignored. + shared_embed: bool, default = False + The idea behind ``shared_embed`` is described in the Appendix A in the + `TabTransformer paper `_: `'The + goal of having column embedding is to enable the model to distinguish + the classes in one column from those in the other columns'`. In other + words, the idea is to let the model learn which column is embedded + at the time. + add_shared_embed: bool, default = False, + The two embedding sharing strategies are: 1) add the shared embeddings to the column + embeddings or 2) to replace the first ``frac_shared_embed`` with the shared + embeddings. See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings` + frac_shared_embed: float, default = 0.25 + The fraction of embeddings that will be shared (if ``add_shared_embed + = False``) by all the different categories for one particular + column. + continuous_cols: List, Optional, default = None + List with the name of the numeric (aka continuous) columns + embed_continuous_activation: str, default = None + String indicating the activation function to be applied to the + continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported. + cont_norm_layer: str, default = None, + Type of normalization layer applied to the continuous features before + they are embedded. Options are: ``layernorm``, ``batchnorm`` or + ``None``. + input_dim: int, default = 32 + The so-called *dimension of the model*. In general, is the number of + embeddings used to encode the categorical and/or continuous columns. + n_cross_attns: int, default = 1 + Number of times each perceiver block will cross attend to the input + data (i.e. number of cross attention components per perceiver block). + This should normally be 1. However, in the paper they describe some + architectures (normally computer vision-related problems) where the + Perceiver attends multiple times to the input array. Therefore, maybe + multiple cross attention to the input array is also useful in some + cases for tabular data + n_cross_attn_heads: int, default = 4 + Number of attention heads for the cross attention component + n_latents: int, default = 16 + Number of latents. This is the *N* parameter in the paper. As + indicated in the paper, this number should be significantly lower + than *M* (the number of columns in the dataset). Setting *N* closer + to *M* defies the main purpose of the Perceiver, which is to overcome + the transformer quadratic bottleneck + latent_dim: int, default = 128 + Latent dimension. + n_latent_heads: int, default = 4 + Number of attention heads per Latent Transformer + n_latent_blocks: int, default = 4 + Number of transformer encoder blocks (normalised MHA + normalised FF) + per Latent Transformer + n_perceiver_blocks: int, default = 4 + Number of Perceiver blocks defined as [Cross Attention + Latent + Transformer] + share_weights: Boolean, default = False + Boolean indicating if the weights will be shared between Perceiver + blocks + attn_dropout: float, default = 0.2 + Dropout that will be applied to the Multi-Head Attention layers + ff_dropout: float, default = 0.1 + Dropout that will be applied to the FeedForward network + transformer_activation: str, default = "gelu" + Transformer Encoder activation function. ``tanh``, ``relu``, + ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported + mlp_hidden_dims: List, Optional, default = None + MLP hidden dimensions. If not provided it will default to ``[l, 4*l, + 2*l]`` where ``l`` is the MLP input dimension + mlp_activation: str, default = "relu" + MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported + mlp_dropout: float, default = 0.1 + Dropout that will be applied to the final MLP + mlp_batchnorm: bool, default = False + Boolean indicating whether or not to apply batch normalization to the + dense layers + mlp_batchnorm_last: bool, default = False + Boolean indicating whether or not to apply batch normalization to the + last of the dense layers + mlp_linear_first: bool, default = False + Boolean indicating whether the order of the operations in the dense + layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP -> + LIN -> ACT]`` + + Attributes + ---------- + cat_and_cont_embed: ``nn.Module`` + This is the module that processes the categorical and continuous columns + perceiver_blks: ``nn.ModuleDict`` + ModuleDict with the Perceiver blocks + latents: ``nn.Parameter`` + Latents that will be used for prediction + perceiver_mlp: ``nn.Module`` + MLP component in the model + output_dim: int + The output dimension of the model. This is a required attribute + neccesary to build the WideDeep class + + Example + -------- + >>> import torch + >>> from pytorch_widedeep.models import TabPerceiver + >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1) + >>> colnames = ['a', 'b', 'c', 'd', 'e'] + >>> embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)] + >>> continuous_cols = ['e'] + >>> column_idx = {k:v for v,k in enumerate(colnames)} + >>> model = TabPerceiver(column_idx=column_idx, embed_input=embed_input, + ... continuous_cols=continuous_cols, n_latents=2, latent_dim=16, + ... n_perceiver_blocks=2) + >>> out = model(X_tab) + """ + + def __init__( + self, + column_idx: Dict[str, int], + embed_input: Optional[List[Tuple[str, int]]] = None, + embed_dropout: float = 0.1, + full_embed_dropout: bool = False, + shared_embed: bool = False, + add_shared_embed: bool = False, + frac_shared_embed: float = 0.25, + continuous_cols: Optional[List[str]] = None, + embed_continuous_activation: str = None, + cont_norm_layer: str = None, + input_dim: int = 32, + n_cross_attns: int = 1, + n_cross_attn_heads: int = 4, + n_latents: int = 16, + latent_dim: int = 128, + n_latent_heads: int = 4, + n_latent_blocks: int = 4, + n_perceiver_blocks: int = 4, + share_weights: bool = False, + attn_dropout: float = 0.1, + ff_dropout: float = 0.1, + transformer_activation: str = "geglu", + mlp_hidden_dims: Optional[List[int]] = None, + mlp_activation: str = "relu", + mlp_dropout: float = 0.1, + mlp_batchnorm: bool = False, + mlp_batchnorm_last: bool = False, + mlp_linear_first: bool = True, + ): + super(TabPerceiver, self).__init__() + + self.column_idx = column_idx + self.embed_input = embed_input + self.embed_dropout = embed_dropout + self.full_embed_dropout = full_embed_dropout + self.shared_embed = shared_embed + self.add_shared_embed = add_shared_embed + self.frac_shared_embed = frac_shared_embed + self.continuous_cols = continuous_cols + self.embed_continuous_activation = embed_continuous_activation + self.cont_norm_layer = cont_norm_layer + self.input_dim = input_dim + self.n_cross_attns = n_cross_attns + self.n_cross_attn_heads = n_cross_attn_heads + self.n_latents = n_latents + self.latent_dim = latent_dim + self.n_latent_heads = n_latent_heads + self.n_latent_blocks = n_latent_blocks + self.n_perceiver_blocks = n_perceiver_blocks + self.share_weights = share_weights + self.attn_dropout = attn_dropout + self.ff_dropout = ff_dropout + self.transformer_activation = transformer_activation + self.mlp_hidden_dims = mlp_hidden_dims + self.mlp_activation = mlp_activation + self.mlp_batchnorm = mlp_batchnorm + self.mlp_batchnorm_last = mlp_batchnorm_last + self.mlp_linear_first = mlp_linear_first + + if mlp_hidden_dims is not None: + assert ( + mlp_hidden_dims[0] == latent_dim + ), "The first mlp input dim must be equal to 'latent_dim'" + + # This should be named 'cat_and_cont_embed' since the continuous cols + # will always be embedded for the TabPerceiver. However is very + # convenient for other funcionalities to name + # it 'cat_and_cont_embed' + self.cat_and_cont_embed = CatAndContEmbeddings( + input_dim, + column_idx, + embed_input, + embed_dropout, + full_embed_dropout, + shared_embed, + add_shared_embed, + frac_shared_embed, + False, # use_embed_bias + continuous_cols, + True, # embed_continuous, + embed_continuous_activation, + True, # use_cont_bias + cont_norm_layer, + ) + + self.latents = nn.init.trunc_normal_( + nn.Parameter(torch.empty(n_latents, latent_dim)) + ) + + self.perceiver_blks = nn.ModuleDict() + first_perceiver_block = self._build_perceiver_block() + self.perceiver_blks["perceiver_block0"] = first_perceiver_block + + if share_weights: + for n in range(1, n_perceiver_blocks): + self.perceiver_blks["perceiver_block" + str(n)] = first_perceiver_block + else: + for n in range(1, n_perceiver_blocks): + self.perceiver_blks[ + "perceiver_block" + str(n) + ] = self._build_perceiver_block() + + if not mlp_hidden_dims: + self.mlp_hidden_dims = [latent_dim, latent_dim * 4, latent_dim * 2] + else: + assert mlp_hidden_dims[0] == latent_dim, ( + f"The input dim of the MLP must be {latent_dim}. " + f"Got {mlp_hidden_dims[0]} instead" + ) + self.perceiver_mlp = MLP( + self.mlp_hidden_dims, + mlp_activation, + mlp_dropout, + mlp_batchnorm, + mlp_batchnorm_last, + mlp_linear_first, + ) + + # the output_dim attribute will be used as input_dim when "merging" the models + self.output_dim = self.mlp_hidden_dims[-1] + + def forward(self, X: Tensor) -> Tensor: + + x_cat, x_cont = self.cat_and_cont_embed(X) + if x_cat is not None: + x_emb = x_cat + if x_cont is not None: + x_emb = torch.cat([x_emb, x_cont], 1) if x_cat is not None else x_cont + + x = einops.repeat(self.latents, "n d -> b n d", b=X.shape[0]) + + for n in range(self.n_perceiver_blocks): + cross_attns = self.perceiver_blks["perceiver_block" + str(n)]["cross_attns"] + latent_transformer = self.perceiver_blks["perceiver_block" + str(n)][ + "latent_transformer" + ] + for cross_attn in cross_attns: + x = cross_attn(x, x_emb) + x = latent_transformer(x) + + # average along the latent index axis + x = x.mean(dim=1) + + return self.perceiver_mlp(x) + + @property + def attention_weights(self) -> List: + r"""List with the attention weights. If the weights are not shared + between perceiver blocks each element of the list will be a list + itself containing the Cross Attention and Latent Transformer + attention weights respectively + + The shape of the attention weights is: + + - Cross Attention: :math:`(N, C, L, F)` + - Latent Attention: :math:`(N, T, L, L)` + + WHere *N* is the batch size, *C* is the number of Cross Attention + heads, *L* is the number of Latents, *F* is the number of + features/columns in the dataset and *T* is the number of Latent + Attention heads + """ + if self.share_weights: + cross_attns = self.perceiver_blks["perceiver_block0"]["cross_attns"] + latent_transformer = self.perceiver_blks["perceiver_block0"][ + "latent_transformer" + ] + attention_weights = self._extract_attn_weights( + cross_attns, latent_transformer + ) + else: + attention_weights = [] + for n in range(self.n_perceiver_blocks): + cross_attns = self.perceiver_blks["perceiver_block" + str(n)][ + "cross_attns" + ] + latent_transformer = self.perceiver_blks["perceiver_block" + str(n)][ + "latent_transformer" + ] + attention_weights.append( + self._extract_attn_weights(cross_attns, latent_transformer) + ) + return attention_weights + + def _build_perceiver_block(self) -> nn.ModuleDict: + + perceiver_block = nn.ModuleDict() + + # Cross Attention + cross_attns = nn.ModuleList() + for _ in range(self.n_cross_attns): + cross_attns.append( + PerceiverEncoder( + self.input_dim, + self.n_cross_attn_heads, + False, # use_bias + self.attn_dropout, + self.ff_dropout, + self.transformer_activation, + self.latent_dim, # q_dim, + ), + ) + perceiver_block["cross_attns"] = cross_attns + + # Latent Transformer + latent_transformer = nn.Sequential() + for i in range(self.n_latent_blocks): + latent_transformer.add_module( + "latent_block" + str(i), + PerceiverEncoder( + self.latent_dim, # input_dim + self.n_latent_heads, + False, # use_bias + self.attn_dropout, + self.ff_dropout, + self.transformer_activation, + ), + ) + perceiver_block["latent_transformer"] = latent_transformer + + return perceiver_block + + @staticmethod + def _extract_attn_weights(cross_attns, latent_transformer) -> List: + attention_weights = [] + for cross_attn in cross_attns: + attention_weights.append(cross_attn.attn.attn_weights) + for latent_block in latent_transformer: + attention_weights.append(latent_block.attn.attn_weights) + return attention_weights diff --git a/pytorch_widedeep/models/transformers/tab_transformer.py b/pytorch_widedeep/models/transformers/tab_transformer.py index 0d840dcbef05893333c6ac0808c5cf1cfd615279..4ea4b9c7a58d763259c01d01f288c48247216ecf 100644 --- a/pytorch_widedeep/models/transformers/tab_transformer.py +++ b/pytorch_widedeep/models/transformers/tab_transformer.py @@ -3,89 +3,93 @@ from torch import nn from pytorch_widedeep.wdtypes import * # noqa: F403 from pytorch_widedeep.models.tab_mlp import MLP -from pytorch_widedeep.models.transformers.layers import ( - SharedEmbeddings, - TransformerEncoder, - ContinuousEmbeddings, - FullEmbeddingDropout, +from pytorch_widedeep.models.transformers._encoders import TransformerEncoder +from pytorch_widedeep.models.transformers._embeddings_layers import ( + CatAndContEmbeddings, ) class TabTransformer(nn.Module): - r"""Adaptation of TabTransformer model - (https://arxiv.org/abs/2012.06678) model that can be used as the - deeptabular component of a Wide & Deep model. + r"""Defines a ``TabTransformer`` model + (`arXiv:2012.06678 `_) that can be used + as the ``deeptabular`` component of a Wide & Deep model. + + Note that this is an enhanced adaptation of the model described in the + original publication, containing a series of additional features. Parameters ---------- column_idx: Dict Dict containing the index of the columns that will be passed through - the DeepDense model. Required to slice the tensors. e.g. {'education': - 0, 'relationship': 1, 'workclass': 2, ...} + the model. Required to slice the tensors. e.g. + {'education': 0, 'relationship': 1, 'workclass': 2, ...} embed_input: List List of Tuples with the column name and number of unique values - e.g. [(education, 11), ...] + e.g. [('education', 11), ...] embed_dropout: float, default = 0.1 Dropout to be applied to the embeddings matrix full_embed_dropout: bool, default = False Boolean indicating if an entire embedding (i.e. the representation of one column) will be dropped in the batch. See: - :obj:`pytorch_widedeep.models.transformers.layers.FullEmbeddingDropout`. + :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`. If ``full_embed_dropout = True``, ``embed_dropout`` is ignored. shared_embed: bool, default = False The idea behind ``shared_embed`` is described in the Appendix A in the paper: `'The goal of having column embedding is to enable the model to distinguish the classes in one column from those in the other columns'`. In other words, the idea - is to let the model learn which column is embedding at the time. - add_shared_embed: bool, default = False, + is to let the model learn which column is embedded at the time. + add_shared_embed: bool, default = False The two embedding sharing strategies are: 1) add the shared embeddings to the column embeddings or 2) to replace the first ``frac_shared_embed`` with the shared - embeddings. See :obj:`pytorch_widedeep.models.transformers.layers.SharedEmbeddings` + embeddings. See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings` frac_shared_embed: float, default = 0.25 - The fraction of embeddings that will be shared by all the different categories for - one particular column. + The fraction of embeddings that will be shared (if ``add_shared_embed + = False``) by all the different categories for one particular + column. continuous_cols: List, Optional, default = None List with the name of the numeric (aka continuous) columns - embed_continuous: bool, default = False, + embed_continuous: bool, default = False Boolean indicating if the continuous features will be "embedded". See - ``pytorch_widedeep.models.transformers.layers.ContinuousEmbeddings`` - Note that setting this to true is equivalent to the so called - `FT-Transformer `_ - (Feature Tokenizer + Transformer). The only difference is that this - implementation does not consider using bias for the categorical - embeddings. + ``pytorch_widedeep.models.transformers._layers.ContinuousEmbeddings`` + Note that setting this to ``True`` is similar (but not identical) to the + so called `FT-Transformer `_ + (Feature Tokenizer + Transformer). + See :obj:`pytorch_widedeep.models.transformers.ft_transformer.FTTransformer` + for details on the dedicated implementation available in this + library embed_continuous_activation: str, default = None String indicating the activation function to be applied to the - continuous embeddings, if any. - 'relu', 'leaky_relu' and 'gelu' are supported. - cont_norm_layer: str, default = None, - Type of normalization layer applied to the continuous features. Options - are: 'layernorm', 'batchnorm' or None. + continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported. + cont_norm_layer: str, default = "layernorm", + Type of normalization layer applied to the continuous features before + they are passed to the network. Options are: ``layernorm``, + ``batchnorm`` or ``None``. input_dim: int, default = 32 - The so-called *dimension of the model*. Is the number of embeddings used to encode - the categorical columns + The so-called *dimension of the model*. In general is the number of + embeddings used to encode the categorical and/or continuous columns n_heads: int, default = 8 Number of attention heads per Transformer block - n_blocks: int, default = 6 + use_bias: bool, default = False + Boolean indicating whether or not to use bias in the Q, K, and V + projection layers. + n_blocks: int, default = 4 Number of Transformer blocks - dropout: float, default = 0.1 - Dropout that will be applied internally to the ``TransformerEncoder`` - (see :obj:`pytorch_widedeep.models.transformers.layers.TransformerEncoder`) - and the output MLP - keep_attn_weights: bool, default = False - If set to ``True`` the model will store the attention weights in the ``attention_weights`` - attribute. - ff_hidden_dim: int, default = 128 - Hidden dimension of the ``FeedForward`` Layer. See - :obj:`pytorch_widedeep.models.transformers.layers.FeedForward`. + attn_dropout: float, default = 0.2 + Dropout that will be applied to the Multi-Head Attention layers + ff_dropout: float, default = 0.1 + Dropout that will be applied to the FeedForward network transformer_activation: str, default = "gelu" - Transformer Encoder activation function. 'relu', 'leaky_relu', 'gelu' - and 'geglu' are supported + Transformer Encoder activation function. ``tanh``, ``relu``, + ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported mlp_hidden_dims: List, Optional, default = None - MLP hidden dimensions. If not provided it will default to ``[4*l, - 2*l]`` where ``l`` is the mlp input dimension + MLP hidden dimensions. If not provided it will default to ``[l, 4*l, + 2*l]`` where ``l`` is the MLP input dimension mlp_activation: str, default = "relu" - MLP activation function. 'relu', 'leaky_relu' and 'gelu' are supported + MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and + ``gelu`` are supported + mlp_dropout: float, default = 0.1 + Dropout that will be applied to the final MLP mlp_batchnorm: bool, default = False Boolean indicating whether or not to apply batch normalization to the dense layers @@ -99,17 +103,10 @@ class TabTransformer(nn.Module): Attributes ---------- - cat_embed_layers: ``nn.ModuleDict`` - Dict with the embeddings per column - cont_embed: ``nn.Module`` - Continuous embeddings layer if ``embed_continuous=True``. See - ``pytorch_widedeep.models.transformers.layers.ContinuousEmbeddings`` - cont_norm: ``nn.Module`` - continuous normalization layer + cat_and_cont_embed: ``nn.Module`` + This is the module that processes the categorical and continuous columns transformer_blks: ``nn.Sequential`` Sequence of Transformer blocks - attention_weights: List - List with the attention weights per block if ``keep_attn_weights = True``. transformer_mlp: ``nn.Module`` MLP component in the model output_dim: int @@ -132,7 +129,7 @@ class TabTransformer(nn.Module): def __init__( self, column_idx: Dict[str, int], - embed_input: List[Tuple[str, int]], + embed_input: Optional[List[Tuple[str, int]]] = None, embed_dropout: float = 0.1, full_embed_dropout: bool = False, shared_embed: bool = False, @@ -144,13 +141,14 @@ class TabTransformer(nn.Module): cont_norm_layer: str = None, input_dim: int = 32, n_heads: int = 8, - n_blocks: int = 6, - dropout: float = 0.1, - keep_attn_weights: bool = False, - ff_hidden_dim: int = 32 * 4, + use_bias: bool = False, + n_blocks: int = 4, + attn_dropout: float = 0.2, + ff_dropout: float = 0.1, transformer_activation: str = "gelu", mlp_hidden_dims: Optional[List[int]] = None, mlp_activation: str = "relu", + mlp_dropout: float = 0.1, mlp_batchnorm: bool = False, mlp_batchnorm_last: bool = False, mlp_linear_first: bool = True, @@ -170,10 +168,10 @@ class TabTransformer(nn.Module): self.cont_norm_layer = cont_norm_layer self.input_dim = input_dim self.n_heads = n_heads + self.use_bias = use_bias self.n_blocks = n_blocks - self.dropout = dropout - self.keep_attn_weights = keep_attn_weights - self.ff_hidden_dim = ff_hidden_dim + self.attn_dropout = attn_dropout + self.ff_dropout = ff_dropout self.transformer_activation = transformer_activation self.mlp_hidden_dims = mlp_hidden_dims self.mlp_activation = mlp_activation @@ -181,36 +179,62 @@ class TabTransformer(nn.Module): self.mlp_batchnorm_last = mlp_batchnorm_last self.mlp_linear_first = mlp_linear_first - self.with_cls_token = "cls_token" in self.column_idx - self.categorical_cols = [ei[0] for ei in self.embed_input] - self.n_tokens = sum([ei[1] for ei in self.embed_input]) + self.with_cls_token = "cls_token" in column_idx + self.n_cat = len(embed_input) if embed_input is not None else 0 + self.n_cont = len(continuous_cols) if continuous_cols is not None else 0 - self._set_categ_embeddings() + if self.n_cont and not self.n_cat and not self.embed_continuous: + raise ValueError( + "If only continuous features are used 'embed_continuous' must be set to 'True'" + ) - self._set_cont_cols() + self.cat_and_cont_embed = CatAndContEmbeddings( + input_dim, + column_idx, + embed_input, + embed_dropout, + full_embed_dropout, + shared_embed, + add_shared_embed, + frac_shared_embed, + False, # use_embed_bias + continuous_cols, + embed_continuous, + embed_continuous_activation, + True, # use_cont_bias + cont_norm_layer, + ) self.transformer_blks = nn.Sequential() for i in range(n_blocks): self.transformer_blks.add_module( - "block" + str(i), + "transformer_block" + str(i), TransformerEncoder( input_dim, n_heads, - keep_attn_weights, - ff_hidden_dim, - dropout, + use_bias, + attn_dropout, + ff_dropout, transformer_activation, ), ) - if keep_attn_weights: - self.attention_weights: List[Any] = [None] * n_blocks + attn_output_dim = self._compute_attn_output_dim() if not mlp_hidden_dims: - mlp_hidden_dims = self._set_mlp_hidden_dims() + mlp_hidden_dims = [ + attn_output_dim, + attn_output_dim * 4, + attn_output_dim * 2, + ] + else: + assert mlp_hidden_dims[0] == attn_output_dim, ( + f"The input dim of the MLP must be {attn_output_dim}. " + f"Got {mlp_hidden_dims[0]} instead" + ) self.transformer_mlp = MLP( mlp_hidden_dims, mlp_activation, - dropout, + mlp_dropout, mlp_batchnorm, mlp_batchnorm_last, mlp_linear_first, @@ -221,115 +245,48 @@ class TabTransformer(nn.Module): def forward(self, X: Tensor) -> Tensor: - if self.shared_embed: - x_cat_embed = [ - self.cat_embed["emb_layer_" + col]( - X[:, self.column_idx[col]].long() - ).unsqueeze(1) - for col, _ in self.embed_input - ] - x = torch.cat(x_cat_embed, 1) - else: - x = self.cat_embed(X[:, self.cat_idx].long()) - - if not self.shared_embed and self.embedding_dropout is not None: - x = self.embedding_dropout(x) + x_cat, x_cont = self.cat_and_cont_embed(X) - if self.continuous_cols is not None and self.embed_continuous: - x_cont = self.cont_norm((X[:, self.cont_idx].float())) - x_cont_embed = self.cont_embed(x_cont) - x = torch.cat([x, x_cont_embed], 1) + if x_cat is not None: + x = x_cat + if x_cont is not None and self.embed_continuous: + x = torch.cat([x, x_cont], 1) if x_cat is not None else x_cont - for i, blk in enumerate(self.transformer_blks): - x = blk(x) - if self.keep_attn_weights: - if hasattr(blk, "row_attn"): - self.attention_weights[i] = ( - blk.self_attn.attn_weights, - blk.row_attn.attn_weights, - ) - else: - self.attention_weights[i] = blk.self_attn.attn_weights + x = self.transformer_blks(x) if self.with_cls_token: x = x[:, 0, :] else: x = x.flatten(1) - if self.continuous_cols is not None and not self.embed_continuous: - x_cont = self.cont_norm((X[:, self.cont_idx].float())) + if x_cont is not None and not self.embed_continuous: x = torch.cat([x, x_cont], 1) return self.transformer_mlp(x) - def _set_categ_embeddings(self): - self.cat_idx = [self.column_idx[col] for col in self.categorical_cols] - # Categorical: val + 1 because 0 is reserved for padding/unseen cateogories. - if self.shared_embed: - self.cat_embed = nn.ModuleDict( - { - "emb_layer_" - + col: SharedEmbeddings( - val if col == "cls_token" else val + 1, - self.input_dim, - self.embed_dropout, - self.full_embed_dropout, - self.add_shared_embed, - self.frac_shared_embed, - ) - for col, val in self.embed_input - } - ) - else: - self.cat_embed = nn.Embedding( - self.n_tokens + 1, self.input_dim, padding_idx=0 - ) - if self.full_embed_dropout: - self.embedding_dropout: DropoutLayers = FullEmbeddingDropout( - self.embed_dropout - ) - else: - self.embedding_dropout = nn.Dropout(self.embed_dropout) + @property + def attention_weights(self) -> List: + r"""List with the attention weights - def _set_cont_cols(self): - if self.continuous_cols is not None: - self.cont_idx = [self.column_idx[col] for col in self.continuous_cols] - if self.cont_norm_layer == "layernorm": - self.cont_norm: NormLayers = nn.LayerNorm(len(self.continuous_cols)) - elif self.cont_norm_layer == "batchnorm": - self.cont_norm = nn.BatchNorm1d(len(self.continuous_cols)) - else: - self.cont_norm = nn.Identity() - if self.embed_continuous: - self.cont_embed = ContinuousEmbeddings( - len(self.continuous_cols), - self.input_dim, - self.embed_continuous_activation, - ) + The shape of the attention weights is: + + :math:`(N, H, F, F)` - def _set_mlp_hidden_dims(self) -> List[int]: - if self.continuous_cols is not None: - if self.with_cls_token: - if self.embed_continuous: - mlp_hidden_dims = [ - self.input_dim, - self.input_dim * 4, - self.input_dim * 2, - ] - else: - mlp_inp_l = self.input_dim + len(self.continuous_cols) - mlp_hidden_dims = [mlp_inp_l, mlp_inp_l * 4, mlp_inp_l * 2] - elif self.embed_continuous: - mlp_inp_l = ( - len(self.embed_input) + len(self.continuous_cols) - ) * self.input_dim - mlp_hidden_dims = [mlp_inp_l, mlp_inp_l * 4, mlp_inp_l * 2] + Where *N* is the batch size, *H* is the number of attention heads + and *F* is the number of features/columns in the dataset + """ + return [blk.attn.attn_weights for blk in self.transformer_blks] + + def _compute_attn_output_dim(self) -> int: + + if self.with_cls_token: + if self.embed_continuous: + attn_output_dim = self.input_dim else: - mlp_inp_l = len(self.embed_input) * self.input_dim + len( - self.continuous_cols - ) - mlp_hidden_dims = [mlp_inp_l, mlp_inp_l * 4, mlp_inp_l * 2] + attn_output_dim = self.input_dim + self.n_cont + elif self.embed_continuous: + attn_output_dim = (self.n_cat + self.n_cont) * self.input_dim else: - mlp_inp_l = len(self.embed_input) * self.input_dim - mlp_hidden_dims = [mlp_inp_l, mlp_inp_l * 4, mlp_inp_l * 2] - return mlp_hidden_dims + attn_output_dim = self.n_cat * self.input_dim + self.n_cont + + return attn_output_dim diff --git a/pytorch_widedeep/models/wide_deep.py b/pytorch_widedeep/models/wide_deep.py index 08f94399c78a13d2b6944656966511781ba37dfd..2c86b85ad950e3b1a12e279d3a58d932c4abb0c0 100644 --- a/pytorch_widedeep/models/wide_deep.py +++ b/pytorch_widedeep/models/wide_deep.py @@ -1,3 +1,15 @@ +""" +During the development of the package I realised that there is a typing +inconsistency. The input components of a Wide and Deep model are of type +nn.Module. These change type internally to nn.Sequential. While nn.Sequential +is an instance of nn.Module the oppossite is, of course, not true. This does +not affect any funcionality of the package, but it is something that needs +fixing. However, while fixing is simple (simply define new attributes that +are the nn.Sequential objects), its implications are quite wide within the +package (involves changing a number of tests and tutorials). Therefore, I +will introduce that fix when I do a major release. For now, we live with it. +""" + import warnings import torch @@ -30,51 +42,16 @@ class WideDeep(nn.Module): Parameters ---------- wide: ``nn.Module``, Optional, default = None - ``Wide`` model. I recommend using the :obj:`Wide` class in this + ``Wide`` model. I recommend using the ``Wide`` class in this package. However, it is possible to use a custom model as long as is consistent with the required architecture, see :class:`pytorch_widedeep.models.wide.Wide` deeptabular: ``nn.Module``, Optional, default = None - - currently ``pytorch-widedeep`` implements four possible - architectures for the `deeptabular` component. These are: - TabMlp, TabResnet, TabNet, TabTransformer and SAINT. - - 1. TabMlp is simply an embedding layer encoding the categorical - features that are then concatenated and passed through a series of - dense (hidden) layers (i.e. and MLP). - See: :obj:`pytorch_widedeep.models.tab_mlp.TabMlp` - - 2. TabResnet is an embedding layer encoding the categorical - features that are then concatenated and passed through a series of - ResNet blocks formed by dense layers. - See :obj:`pytorch_widedeep.models.tab_resnet.TabResnet` - - 3. TabNet is detailed in `TabNet: Attentive Interpretable Tabular - Learning `_. The TabNet - implementation in ``pytorch_widedeep`` is an adaptation of the - `dreamquark-ai `_ - implementation. See - :obj:`pytorch_widedeep.models.tabnet.tab_net.TabNet` - - 3. TabTransformer is detailed in `TabTransformer: Tabular Data - Modeling Using Contextual Embeddings - `_. The TabTransformer - implementation in ``pytorch-widedeep`` is an adaptation of the - original implementation. See - :obj:`pytorch_widedeep.models.transformers.tab_transformer.TabTransformer`. - - 3. SAINT is detailed in `SAINT: Improved Neural Networks for Tabular - Data via Row Attention and Contrastive Pre-Training - `_. The SAINT implementation in - ``pytorch-widedeep`` is an adaptation of the original implementation. - See - :obj:`pytorch_widedeep.models.transformers.saint.SAINT`. - - I recommend using on of these as ``deeptabular``. However, it is - possible to use a custom model as long as is consistent with the - required architecture. - + currently ``pytorch-widedeep`` implements a number of possible + architectures for the ``deeptabular`` component. See the documenation + of the package. I recommend using the ``deeptabular`` components in + this package. However, it is possible to use a custom model as long + as is consistent with the required architecture. deeptext: ``nn.Module``, Optional, default = None Model for the text input. Must be an object of class ``DeepText`` or a custom model as long as is consistent with the required @@ -97,8 +74,8 @@ class WideDeep(nn.Module): If ``head_hidden_dims`` is not None, dropout between the layers in ``head_hidden_dims`` head_activation: str, default = "relu" - If ``head_hidden_dims`` is not None, activation function of the - head layers. One of "relu", gelu" or "leaky_relu" + If ``head_hidden_dims`` is not None, activation function of the head + layers. One of ``tanh``, ``relu``, ``gelu`` or ``leaky_relu`` head_batchnorm: bool, default = False If ``head_hidden_dims`` is not None, specifies if batch normalizatin should be included in the head layers diff --git a/pytorch_widedeep/preprocessing/tab_preprocessor.py b/pytorch_widedeep/preprocessing/tab_preprocessor.py index c914d99a432b1c323aa048237b7f862c9b7be021..0cc308911bb990b4ff95981bf3e83ad4f0031321 100644 --- a/pytorch_widedeep/preprocessing/tab_preprocessor.py +++ b/pytorch_widedeep/preprocessing/tab_preprocessor.py @@ -30,12 +30,12 @@ class TabPreprocessor(BasePreprocessor): continuous_cols: List, default = None List with the name of the so called continuous cols scale: bool, default = True - Bool indicating whether or not to scale/standarise continuous - cols. The user should bear in mind that all the ``deeptabular`` - components available within ``pytorch-widedeep`` they also include - the possibility of normalising the input continuous features via a + Bool indicating whether or not to scale/standarise continuous cols. + The user should bear in mind that all the ``deeptabular`` components + available within ``pytorch-widedeep`` they also include the + possibility of normalising the input continuous features via a ``BatchNorm`` or a ``LayerNorm`` layer. See - :class:`pytorch_widedeep.models` + :obj:`pytorch_widedeep.models.transformers._embedding_layers` auto_embed_dim: bool, default = True Boolean indicating whether the embedding dimensions will be automatically defined via fastai's rule of thumb': @@ -53,26 +53,27 @@ class TabPreprocessor(BasePreprocessor): tabular library) and not standarize them any further for_transformer: bool, default = False Boolean indicating whether the preprocessed data will be passed to a - transformer-based model (i.e. ``TabTransformer`` or ``SAINT``). If - ``True``, the param ``embed_cols`` must just be a list containing the - categorical columns: e.g.:['education', 'relationship', ...] This is - because they will all be encoded using embeddings of the same dim - (32 by default). + transformer-based model + (See :obj:`pytorch_widedeep.models.transformers`). If ``True``, the + param ``embed_cols`` must just be a list containing the categorical + columns: e.g.:['education', 'relationship', ...] This is because they + will all be encoded using embeddings of the same dim. with_cls_token: bool, default = False Boolean indicating if a `'[CLS]'` token will be added to the dataset - when using transformer-based models (i.e. ``TabTransformer`` or - ``SAINT``). The final hidden state corresponding to this token is - used as the aggregate row representation for classification and - regression tasks. If not, the categorical (and continuous embeddings - if present) will be concatenated before being passed to the final - MLP. + when using transformer-based models. The final hidden state + corresponding to this token is used as the aggregated representation + for classification and regression tasks. If not, the categorical + (and continuous embeddings if present) will be concatenated before + being passed to the final MLP. shared_embed: bool, default = False - This parameter will only be used by the ``TabPreprocessor`` when the - data is being prepapred for a transformer-based model. If that is the - case and the embeddings are 'shared' - (see: - ``pytorch_widedeep.models.transformers.layers.SharedEmbeddings``) - then each column will be embed indepedently. + Boolean indicating if the embeddings will be "shared" when using + transformer-based models. The idea behind ``shared_embed`` is + described in the Appendix A in the `TabTransformer paper + `_: `'The goal of having column + embedding is to enable the model to distinguish the classes in one + column from those in the other columns'`. In other words, the idea is + to let the model learn which column is embedded at the time. See: + :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`. verbose: int, default = 1 Attributes diff --git a/pytorch_widedeep/tab2vec.py b/pytorch_widedeep/tab2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..bd46305ada6bf8534a12fd59e5bb28182c46bcbc --- /dev/null +++ b/pytorch_widedeep/tab2vec.py @@ -0,0 +1,167 @@ +import warnings +from copy import deepcopy + +import numpy as np +import torch +import einops +import pandas as pd + +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.preprocessing import TabPreprocessor + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Tab2Vec: + r"""Class to transform an input dataframe into vectorized form. + + This class will take an input dataframe in the form of the dataframe used + for training, and it will turn it into a vectorised form based on the + processing applied to the categorical and continuous columns. + + .. note:: Currently this class is only implemented for the deeptabular component. + Therefore, if the input dataframe has a text column or a column with + the path to images, these will be ignored. We will be adding these + functionalities in future versions + + Parameters + ---------- + model: ``WideDeep`` + ``WideDeep`` model. Must be trained. + See :obj:`pytorch-widedeep.models.wide_deep.WideDeep` + tab_preprocessor: ``TabPreprocessor`` + ``TabPreprocessor`` object. Must be fitted. + See :obj:`pytorch-widedeep.preprocessing.tab_preprocessor.TabPreprocessor` + + Attributes + ---------- + vectorizer: ``nn.Module`` + Torch module with the categorical and continuous encoding process + + Examples + -------- + >>> import string + >>> from random import choices + >>> import numpy as np + >>> import pandas as pd + >>> from pytorch_widedeep import Tab2Vec + >>> from pytorch_widedeep.models import TabMlp, WideDeep + >>> from pytorch_widedeep.preprocessing import TabPreprocessor + >>> + >>> colnames = list(string.ascii_lowercase)[:4] + >>> cat_col1_vals = ["a", "b", "c"] + >>> cat_col2_vals = ["d", "e", "f"] + >>> + >>> # Create the toy input dataframe and a toy dataframe to be vectorised + >>> cat_inp = [np.array(choices(c, k=5)) for c in [cat_col1_vals, cat_col2_vals]] + >>> cont_inp = [np.round(np.random.rand(5), 2) for _ in range(2)] + >>> df_inp = pd.DataFrame(np.vstack(cat_inp + cont_inp).transpose(), columns=colnames) + >>> cat_t2v = [np.array(choices(c, k=5)) for c in [cat_col1_vals, cat_col2_vals]] + >>> cont_t2v = [np.round(np.random.rand(5), 2) for _ in range(2)] + >>> df_t2v = pd.DataFrame(np.vstack(cat_t2v + cont_t2v).transpose(), columns=colnames) + >>> + >>> # fit the TabPreprocessor + >>> embed_cols = [("a", 2), ("b", 4)] + >>> cont_cols = ["c", "d"] + >>> tab_preprocessor = TabPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols) + >>> X_tab = tab_preprocessor.fit_transform(df_inp) + >>> + >>> # define the model (and let's assume we train it) + >>> tabmlp = TabMlp( + ... column_idx=tab_preprocessor.column_idx, + ... embed_input=tab_preprocessor.embeddings_input, + ... continuous_cols=tab_preprocessor.continuous_cols, + ... mlp_hidden_dims=[8, 4]) + >>> model = WideDeep(deeptabular=tabmlp) + >>> # ...train the model... + >>> + >>> # vectorise the dataframe + >>> t2v = Tab2Vec(model, tab_preprocessor) + >>> X_vec = t2v.transform(df_t2v) + """ + + def __init__( + self, model: WideDeep, tab_preprocessor: TabPreprocessor, verbose: bool = False + ): + super(Tab2Vec, self).__init__() + + if verbose: + if model.deepimage is not None or model.deeptext is not None: + warnings.warn( + "Currently 'Tab2Vec' is only implemented for the 'deeptabular' component." + ) + + if model.deeptabular is None: + raise RuntimeError( + "Currently 'Tab2Vec' is only implemented for the 'deeptabular' component." + ) + if not tab_preprocessor.is_fitted: + raise RuntimeError( + "The 'tab_preprocessor' must be fitted before is passed to 'Tab2Vec'" + ) + + self.tab_preprocessor = tab_preprocessor + + transformer_family = [ + "tabtransformer", + "saint", + "fttransformer", + "tabperceiver", + "tabfastformer", + ] + self.is_transformer = ( + model.deeptabular[0].__class__.__name__.lower() in transformer_family # type: ignore[index] + ) + self.vectorizer = ( + deepcopy(model.deeptabular[0].cat_embed_and_cont) # type: ignore[index] + if not self.is_transformer + else deepcopy(model.deeptabular[0].cat_and_cont_embed) # type: ignore[index] + ) + self.vectorizer.to(device) + + def fit(self, df: pd.DataFrame, target_col: Optional[str] = None) -> "Tab2Vec": + r"""Empty method. Returns the object itself. Is only included for + consistency in case ``Tab2Vec`` is used as part of a Pipeline + """ + return self + + def transform( + self, df: pd.DataFrame, target_col: Optional[str] = None + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + r""" + Parameters + ---------- + df: pd.DataFrame + DataFrame to be vectorised, i.e. the categorical and continuous + columns will be encoded based on the processing applied within + the model + target_col: str, Optional + Column name of the target_col variable. If 'None' only the array of + predictors will be returned + """ + + X_tab = self.tab_preprocessor.transform(df) + X = torch.from_numpy(X_tab).to(device) + + with torch.no_grad(): + x_cat, x_cont = self.vectorizer(X) + + if self.tab_preprocessor.with_cls_token: + x_cat = x_cat[:, 1:, :] + + if self.is_transformer: + x_cat = einops.rearrange(x_cat, "s c e -> s (c e)") + if len(list(x_cont.shape)) == 3: + x_cont = einops.rearrange(x_cont, "s c e -> s (c e)") + + X_vec = torch.cat([x_cat, x_cont], 1).cpu().numpy() + if target_col: + return X_vec, df[target_col].values + else: + return X_vec + + def fit_transform( + self, df: pd.DataFrame, target_col: Optional[str] = None + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + r"""Combines ``fit`` and ``transform``""" + return self.fit(df, target_col).transform(df, target_col) diff --git a/pytorch_widedeep/training/_loss_and_obj_aliases.py b/pytorch_widedeep/training/_loss_and_obj_aliases.py index 6fd36fdb5a52b4f3b57ed9bef5155c3190cc7ee4..0c9747e38ecf279024d23fee8ead4a6fe877ca66 100644 --- a/pytorch_widedeep/training/_loss_and_obj_aliases.py +++ b/pytorch_widedeep/training/_loss_and_obj_aliases.py @@ -7,7 +7,7 @@ class classproperty: @classmethod @property - Given that we support 3.6, 3.7 and 3.8 let's use this hack + Given that we support 3.7, 3.8 as well as 3.9, let's use this hack """ def __init__(self, func): diff --git a/pytorch_widedeep/training/trainer_utils.py b/pytorch_widedeep/training/_trainer_utils.py similarity index 68% rename from pytorch_widedeep/training/trainer_utils.py rename to pytorch_widedeep/training/_trainer_utils.py index 716cc4ad088cf2140957c71f9d7f27a5df307135..2c63ec9280c64e5a2cfd727aa6c4c319c9206f27 100644 --- a/pytorch_widedeep/training/trainer_utils.py +++ b/pytorch_widedeep/training/_trainer_utils.py @@ -1,24 +1,10 @@ -""" -Code for 'Alias' and 'set_default_attr' taken from the one and only Hunter -McGushion and his library: -https://github.com/HunterMcGushion/hyperparameter_hunter -""" - import numpy as np -import wrapt from tqdm import tqdm from torch import nn from sklearn.model_selection import train_test_split from pytorch_widedeep.losses import MSLELoss, RMSELoss, FocalLoss, RMSLELoss -from pytorch_widedeep.wdtypes import ( - Any, - Dict, - List, - Union, - Optional, - Transforms, -) +from pytorch_widedeep.wdtypes import Dict, List, Optional, Transforms from pytorch_widedeep.training._wd_dataset import WideDeepDataset from pytorch_widedeep.training._loss_and_obj_aliases import ( _LossAliases, @@ -206,7 +192,7 @@ def alias_to_loss(loss_fn: str, **kwargs): Examples -------- - >>> from pytorch_widedeep.training.trainer_utils import alias_to_loss + >>> from pytorch_widedeep.training._trainer_utils import alias_to_loss >>> loss_fn = alias_to_loss(loss_fn="binary_logloss", weight=None) """ if loss_fn not in _ObjectiveToMethod.keys(): @@ -231,93 +217,3 @@ def alias_to_loss(loss_fn: str, **kwargs): return RMSLELoss() if "focal_loss" in loss_fn: return FocalLoss(**kwargs) - - -class Alias: - def __init__(self, primary_name: str, aliases: Union[str, List[str]]): - r"""Convert uses of `aliases` to `primary_name` upon calling the decorated - function/method - - Parameters - ---------- - primary_name: String - Preferred name for the parameter, the value of which will be set - to the value of the used alias. If `primary_name` is already - explicitly used on call in addition to any aliases, the value of - `primary_name` will remain unchanged. It only assumes the value of - an alias if the `primary_name` is not used - aliases: List, string - One or multiple string aliases for `primary_name`. If - `primary_name` is not used on call, its value will be set to that - of a random alias in `aliases`. Before calling the decorated - callable, all `aliases` are removed from its kwargs - - Examples - -------- - >>> class Foo(): - ... @Alias("a", ["a2"]) - ... def __init__(self, a, b=None): - ... print(a, b) - >>> @Alias("a", ["a2"]) - ... @Alias("b", ["b2"]) - ... def bar(a, b=None): - ... print(a, b) - >>> foo = Foo(a2="x", b="y") - x y - >>> bar(a2="x", b2="y") - x y""" - self.primary_name = primary_name - self.aliases = aliases if isinstance(aliases, list) else [aliases] - - @wrapt.decorator - def __call__(self, wrapped, instance, args, kwargs): - for alias in set(self.aliases).intersection(kwargs): - # Only set if no `primary_name` already. Remove `aliases`, leaving only `primary_name` - kwargs.setdefault(self.primary_name, kwargs.pop(alias)) - # Record aliases used in `instance.__wd_aliases_used` or `wrapped.__wd_aliases_used` - if instance: - set_default_attr(instance, "__wd_aliases_used", {})[ - self.primary_name - ] = alias - else: - set_default_attr(wrapped, "__wd_aliases_used", {})[ - self.primary_name - ] = alias - return wrapped(*args, **kwargs) - - -def set_default_attr(obj: Any, name: str, value: Any): - r"""Set the `name` attribute of `obj` to `value` if the attribute does not - already exist - - Parameters - ---------- - obj: Object - Object whose `name` attribute will be returned (after setting it to - `value`, if necessary) - name: String - Name of the attribute to set to `value`, or to return - value: Object - Default value to give to `obj.name` if the attribute does not already - exist - - Returns - ------- - Object - `obj.name` if it exists. Else, `value` - - Examples - -------- - >>> foo = type("Foo", tuple(), {"my_attr": 32}) - >>> set_default_attr(foo, "my_attr", 99) - 32 - >>> set_default_attr(foo, "other_attr", 9000) - 9000 - >>> assert foo.my_attr == 32 - >>> assert foo.other_attr == 9000 - """ - try: - return getattr(obj, name) - except AttributeError: - setattr(obj, name, value) - return value diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 44c37d1aac213c72538ce0cc7851154a686c9204..4334ac9e475c130df68e022f94416497cde01235 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -1,5 +1,6 @@ import os import json +import warnings from pathlib import Path import numpy as np @@ -24,15 +25,15 @@ from pytorch_widedeep.callbacks import ( from pytorch_widedeep.dataloaders import DataLoaderDefault from pytorch_widedeep.initializers import Initializer, MultipleInitializer from pytorch_widedeep.training._finetune import FineTune +from pytorch_widedeep.utils.general_utils import Alias +from pytorch_widedeep.models.tabnet._utils import create_explain_matrix from pytorch_widedeep.training._wd_dataset import WideDeepDataset -from pytorch_widedeep.training.trainer_utils import ( - Alias, +from pytorch_widedeep.training._trainer_utils import ( alias_to_loss, save_epoch_logs, wd_train_val_split, print_loss_and_metric, ) -from pytorch_widedeep.models.tabnet.tab_net_utils import create_explain_matrix from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer from pytorch_widedeep.training._multiple_transforms import MultipleTransforms from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod @@ -47,7 +48,7 @@ device = torch.device("cuda" if use_cuda else "cpu") class Trainer: - r"""Method to set the of attributes that will be used during the + r"""Class to set the of attributes that will be used during the training process. Parameters @@ -152,7 +153,8 @@ class Trainer: need to be normalised. See `this discussion `_. lambda_sparse: float. default=1e-3 - Tabnet sparse regularization factor + Tabnet sparse regularization factor. Used, of course, if the + ``deeptabular`` component is a Tabnet model alpha: float. default=0.25 if ``objective`` is ``binary_focal_loss`` or ``multiclass_focal_loss``, the Focal Loss alpha and gamma @@ -749,6 +751,9 @@ class Trainer: r"""Returns the learned embeddings for the categorical features passed through ``deeptabular``. + .. note:: This function will be deprecated in the next relase. Please consider + using ``Tab2Vec`` instead. + This method is designed to take an encoding dictionary in the same format as that of the :obj:`LabelEncoder` Attribute in the class :obj:`TabPreprocessor`. See @@ -783,6 +788,11 @@ class Trainer: trainer.get_embeddings(col_name="education", cat_encoding_dict=encoding_dict) """ + warnings.warn( + "'get_embeddings' will be deprecated in the next release. " + "Please consider using 'Tab2vec' instead", + DeprecationWarning, + ) for n, p in self.model.named_parameters(): if "embed_layers" in n and col_name in n: embed_mtx = p.cpu().data.numpy() @@ -795,9 +805,10 @@ class Trainer: def explain(self, X_tab: np.ndarray, save_step_masks: bool = False): """ - Returns the aggregated feature importance for each instance (or - observation) in the ``X_tab`` array. If ``save_step_masks`` is set to - ``True``, the masks per step will also be returned. + if the ``deeptabular`` component is a Tabnet model, returns the + aggregated feature importance for each instance (or observation) in + the ``X_tab`` array. If ``save_step_masks`` is set to ``True``, the + masks per step will also be returned. Parameters ---------- @@ -874,9 +885,9 @@ class Trainer: The exception is Tabnet. If the ``deeptabular`` component is a Tabnet model, an attribute (a dict) called ``feature_importance`` will be created at the end of the training process. Therefore, a ``save`` - method was created that will save both the feature importance - dictionary to a json file and, since we are here, the model weights, - training history and learning rate history. + method was created that will save the feature importance dictionary + to a json file and, since we are here, the model weights, training + history and learning rate history. Parameters ---------- diff --git a/pytorch_widedeep/utils/deeptabular_utils.py b/pytorch_widedeep/utils/deeptabular_utils.py index cb9cec78677210d86c02474a482143edb60750fd..b7c672de5e0167af6ccaa215a2b6747996f1aedb 100644 --- a/pytorch_widedeep/utils/deeptabular_utils.py +++ b/pytorch_widedeep/utils/deeptabular_utils.py @@ -25,13 +25,17 @@ class LabelEncoder: encoded. for_transformer: bool, default = False Boolean indicating whether the preprocessed data will be passed to a - transformer-based model (i.e. ``TabTransformer`` or ``SAINT``). + transformer-based model. + See :obj:`pytorch_widedeep.models.transformers` shared_embed: bool, default = False Boolean indicating if the embeddings will be "shared" when using - transformer-based models - (see: - ``pytorch_widedeep.models.transformers.layers.SharedEmbeddings``) - then each column will be embed indepedently. + transformer-based models. The idea behind ``shared_embed`` is + described in the Appendix A in the `TabTransformer paper + `_: `'The goal of having column + embedding is to enable the model to distinguish the classes in one + column from those in the other columns'`. In other words, the idea is + to let the model learn which column is embedded at the time. See: + :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`. Attributes ----------- diff --git a/pytorch_widedeep/utils/fastai_transforms.py b/pytorch_widedeep/utils/fastai_transforms.py index 4814dacf5db3e5ffd65aa62d7eda4559bb8f7381..eb86e3da32f37d26f3656a50bb5086de0550faa9 100644 --- a/pytorch_widedeep/utils/fastai_transforms.py +++ b/pytorch_widedeep/utils/fastai_transforms.py @@ -220,15 +220,15 @@ class Tokenizer: ---------- tok_func: Callable, default = ``SpacyTokenizer`` Tokenizer Object. See :class:`pytorch_widedeep.utils.fastai_transforms.SpacyTokenizer` - lang: str, default = "en", + lang: str, default = "en" Text's Language - pre_rules: ListRules, Optional, default = None, + pre_rules: ListRules, Optional, default = None Custom type: ``Collection[Callable[[str], str]]``. see :obj:`pytorch_widedeep.wdtypes`. Preprocessing Rules - post_rules: ListRules, Optional, default = None, + post_rules: ListRules, Optional, default = None Custom type: ``Collection[Callable[[str], str]]``. see :obj:`pytorch_widedeep.wdtypes`. Postprocessing Rules - special_cases: Collection, Optional, default= None, + special_cases: Collection, Optional, default= None special cases to be added to the tokenizer via ``Spacy``'s ``add_special_case`` method n_cpus: int, Optional, default = None diff --git a/pytorch_widedeep/utils/general_utils.py b/pytorch_widedeep/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..695804bef3618958bb8c9893e3c37dbf51835528 --- /dev/null +++ b/pytorch_widedeep/utils/general_utils.py @@ -0,0 +1,98 @@ +""" +Code for 'Alias' and 'set_default_attr' taken from the one and only Hunter +McGushion and his library: +https://github.com/HunterMcGushion/hyperparameter_hunter +""" +from typing import Any, List, Union + +import wrapt + + +class Alias: + def __init__(self, primary_name: str, aliases: Union[str, List[str]]): + r"""Convert uses of `aliases` to `primary_name` upon calling the decorated + function/method + + Parameters + ---------- + primary_name: String + Preferred name for the parameter, the value of which will be set + to the value of the used alias. If `primary_name` is already + explicitly used on call in addition to any aliases, the value of + `primary_name` will remain unchanged. It only assumes the value of + an alias if the `primary_name` is not used + aliases: List, string + One or multiple string aliases for `primary_name`. If + `primary_name` is not used on call, its value will be set to that + of a random alias in `aliases`. Before calling the decorated + callable, all `aliases` are removed from its kwargs + + Examples + -------- + >>> class Foo(): + ... @Alias("a", ["a2"]) + ... def __init__(self, a, b=None): + ... print(a, b) + >>> @Alias("a", ["a2"]) + ... @Alias("b", ["b2"]) + ... def bar(a, b=None): + ... print(a, b) + >>> foo = Foo(a2="x", b="y") + x y + >>> bar(a2="x", b2="y") + x y""" + self.primary_name = primary_name + self.aliases = aliases if isinstance(aliases, list) else [aliases] + + @wrapt.decorator + def __call__(self, wrapped, instance, args, kwargs): + for alias in set(self.aliases).intersection(kwargs): + # Only set if no `primary_name` already. Remove `aliases`, leaving only `primary_name` + kwargs.setdefault(self.primary_name, kwargs.pop(alias)) + # Record aliases used in `instance.__wd_aliases_used` or `wrapped.__wd_aliases_used` + if instance: + set_default_attr(instance, "__wd_aliases_used", {})[ + self.primary_name + ] = alias + else: + set_default_attr(wrapped, "__wd_aliases_used", {})[ + self.primary_name + ] = alias + return wrapped(*args, **kwargs) + + +def set_default_attr(obj: Any, name: str, value: Any): + r"""Set the `name` attribute of `obj` to `value` if the attribute does not + already exist + + Parameters + ---------- + obj: Object + Object whose `name` attribute will be returned (after setting it to + `value`, if necessary) + name: String + Name of the attribute to set to `value`, or to return + value: Object + Default value to give to `obj.name` if the attribute does not already + exist + + Returns + ------- + Object + `obj.name` if it exists. Else, `value` + + Examples + -------- + >>> foo = type("Foo", tuple(), {"my_attr": 32}) + >>> set_default_attr(foo, "my_attr", 99) + 32 + >>> set_default_attr(foo, "other_attr", 9000) + 9000 + >>> assert foo.my_attr == 32 + >>> assert foo.other_attr == 9000 + """ + try: + return getattr(obj, name) + except AttributeError: + setattr(obj, name, value) + return value diff --git a/pytorch_widedeep/version.py b/pytorch_widedeep/version.py index 68cdeee4b212c9bb82f3c4319a42e06684b24c5e..39e0411d5cdae4ea426b6e3b8012478fa829e175 100644 --- a/pytorch_widedeep/version.py +++ b/pytorch_widedeep/version.py @@ -1 +1 @@ -__version__ = "1.0.5" +__version__ = "1.0.9" diff --git a/pytorch_widedeep/wdtypes.py b/pytorch_widedeep/wdtypes.py index c0d21ad9832bbdee5893a589ed5d48cdc3266cfa..996e75ca7e81cb11a443bf997a5dca83f0eaf358 100644 --- a/pytorch_widedeep/wdtypes.py +++ b/pytorch_widedeep/wdtypes.py @@ -52,7 +52,6 @@ from torch.utils.data.dataloader import DataLoader from pytorch_widedeep.models import WideDeep from pytorch_widedeep.models.tabnet.sparsemax import Entmax15, Sparsemax -from pytorch_widedeep.models.transformers.layers import FullEmbeddingDropout ListRules = Collection[Callable[[str], str]] Tokens = Collection[Collection[str]] @@ -86,4 +85,3 @@ Transforms = Union[ LRScheduler = _LRScheduler ModelParams = Generator[Tensor, Tensor, Tensor] NormLayers = Union[torch.nn.Identity, torch.nn.LayerNorm, torch.nn.BatchNorm1d] -DropoutLayers = Union[torch.nn.Dropout, FullEmbeddingDropout] diff --git a/setup.py b/setup.py index 78f81c71e6d3f13aa7bd807c01e4020d7535c6ff..ebe38678822b5957f849a00f7d4926d5e9a4624c 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ setup_kwargs = { "license": "MIT", "install_requires": [ "pandas", - "numpy", + "numpy>=1.20.0", "scipy", "scikit-learn", "gensim", @@ -67,7 +67,7 @@ setup_kwargs = { "torchmetrics", ], "extras_require": extras, - "python_requires": ">=3.6.0", + "python_requires": ">=3.7.0", "classifiers": [ dev_status[majorminor], "Environment :: Other Environment", diff --git a/tests/test_finetune/test_finetuning_routines.py b/tests/test_finetune/test_finetuning_routines.py index d5ef328db328e6913c65aa5b146484524ca82bb6..74e03bd97b15a9ef2b125e046ffc7b4e60899be4 100644 --- a/tests/test_finetune/test_finetuning_routines.py +++ b/tests/test_finetune/test_finetuning_routines.py @@ -147,7 +147,7 @@ finetuner = FineTune(loss_fn, MultipleMetrics([Accuracy()]), "binary", False) # so here we go... last_linear = list(deeptabular.children())[1] inverted_mlp_layers = list( - list(list(deeptabular.named_modules())[10][1].children())[0].children() + list(list(deeptabular.named_modules())[11][1].children())[0].children() )[::-1] tab_layers = [last_linear] + inverted_mlp_layers text_layers = [c for c in list(deeptext.children())[1:]][::-1] diff --git a/tests/test_model_components/test_mc_tab_mlp.py b/tests/test_model_components/test_mc_tab_mlp.py index 753a50e663f8d220636a5230608673afa4bbee21..624794984e2369949123af843e1b2e337010ae5d 100644 --- a/tests/test_model_components/test_mc_tab_mlp.py +++ b/tests/test_model_components/test_mc_tab_mlp.py @@ -5,6 +5,7 @@ import torch import pytest from pytorch_widedeep.models import TabMlp +from pytorch_widedeep.models.tab_mlp import CatEmbeddingsAndCont colnames = list(string.ascii_lowercase)[:10] embed_cols = [np.random.choice(np.arange(5), 10) for _ in range(5)] @@ -90,3 +91,50 @@ def test_act_fn_ValueError(): embed_input=embed_input, continuous_cols=continuous_cols, ) + + +############################################################################### +# Test CatEmbeddingsAndCont +############################################################################### + + +@pytest.mark.parametrize( + "setup, column_idx, embed_input, continuous_cols", + [ + ("w_embed", {k: v for v, k in enumerate(colnames[:5])}, embed_input, None), + ("w_cont", {k: v for v, k in enumerate(colnames[5:])}, None, continuous_cols), + ( + "w_both", + {k: v for v, k in enumerate(colnames)}, + embed_input, + continuous_cols, + ), + ], +) +def test_cat_embeddings_and_cont(setup, column_idx, embed_input, continuous_cols): + + if setup == "w_embed": + X = X_deep_emb + if setup == "w_cont": + X = X_deep_cont + if setup == "w_both": + X = X_deep + + cat_embed_and_cont = CatEmbeddingsAndCont( + column_idx, embed_input, 0.1, continuous_cols, None + ) + x_cat, x_cont = cat_embed_and_cont(X) + + if setup == "w_embed": + assert ( + x_cat.size() == torch.Size([X.shape[0], len(column_idx) * 16]) + and x_cont is None + ) + if setup == "w_cont": + assert ( + x_cont.size() == torch.Size([X.shape[0], len(column_idx)]) and x_cat is None + ) + if setup == "w_both": + assert x_cat.size() == torch.Size( + [X.shape[0], len(embed_cols) * 16] + ) and x_cont.size() == torch.Size([X.shape[0], len(cont_cols)]) diff --git a/tests/test_model_components/test_mc_tab_tabnet.py b/tests/test_model_components/test_mc_tab_tabnet.py index 42b3cebd34c57541d10f16b9a900578cc273d457..e11f9d4c5985928e1777c3b65ede0f3eef117dd6 100644 --- a/tests/test_model_components/test_mc_tab_tabnet.py +++ b/tests/test_model_components/test_mc_tab_tabnet.py @@ -5,8 +5,8 @@ import torch import pytest from pytorch_widedeep.wdtypes import WideDeep +from pytorch_widedeep.models.tabnet._utils import create_explain_matrix from pytorch_widedeep.models.tabnet.tab_net import TabNet # noqa: F403 -from pytorch_widedeep.models.tabnet.tab_net_utils import create_explain_matrix # I am going over test this model due to the number of components @@ -38,7 +38,7 @@ model1 = TabNet( def test_embeddings_have_padding(): res = [] - for k, v in model1.embed_and_cont.embed_layers.items(): + for k, v in model1.cat_embed_and_cont.embed_layers.items(): res.append(v.weight.size(0) == n_embed + 1) res.append(not torch.all(v.weight[0].bool())) assert all(res) diff --git a/tests/test_model_components/test_mc_transformers.py b/tests/test_model_components/test_mc_transformers.py index 7c08a30fa28a69dbd58a37276e64687b03fb59c5..3c5bda959c604fe387aa99f319151e10e909bdf9 100644 --- a/tests/test_model_components/test_mc_transformers.py +++ b/tests/test_model_components/test_mc_transformers.py @@ -5,8 +5,15 @@ import numpy as np import torch import pytest -from pytorch_widedeep.models import SAINT, TabTransformer -from pytorch_widedeep.models.transformers.layers import * # noqa: F403 +from pytorch_widedeep.models import ( + SAINT, + TabPerceiver, + FTTransformer, + TabFastFormer, + TabTransformer, +) +from pytorch_widedeep.models.transformers._attention_layers import * # noqa: F403 +from pytorch_widedeep.models.transformers._embeddings_layers import * # noqa: F403 # I am going over test these models due to the number of components @@ -15,12 +22,12 @@ n_cols = 2 batch_size = 10 colnames = list(string.ascii_lowercase)[: (n_cols * 2)] embed_cols = [np.random.choice(np.arange(n_embed), batch_size) for _ in range(n_cols)] -embed_cols_with_cls_token = [[n_embed] * batch_size] + embed_cols +embed_cols_with_cls_token = [[n_embed] * batch_size] + embed_cols # type: ignore[operator] cont_cols = [np.random.rand(batch_size) for _ in range(n_cols)] X_tab = torch.from_numpy(np.vstack(embed_cols + cont_cols).transpose()) X_tab_with_cls_token = torch.from_numpy( - np.vstack(embed_cols_with_cls_token + cont_cols).transpose() + np.vstack(embed_cols_with_cls_token + cont_cols).transpose() # type: ignore[operator] ) @@ -38,8 +45,13 @@ model1 = TabTransformer( def test_embeddings_have_padding(): res = [] - res.append(model1.cat_embed.weight.size(0) == model1.n_tokens + 1) - res.append(not torch.all(model1.cat_embed.weight[0].bool())) + res.append( + model1.cat_and_cont_embed.cat_embed.embed.weight.size(0) + == model1.cat_and_cont_embed.cat_embed.n_tokens + 1 + ) + res.append( + not torch.all(model1.cat_and_cont_embed.cat_embed.embed.weight[0].bool()) + ) assert all(res) @@ -96,7 +108,7 @@ model2 = TabTransformer( def test_shared_embeddings_have_padding(): res = [] - for k, v in model2.cat_embed.items(): + for k, v in model2.cat_and_cont_embed.cat_embed.embed.items(): res.append(v.embed.weight.size(0) == n_embed + 1) res.append(not torch.all(v.embed.weight[0].bool())) assert all(res) @@ -120,7 +132,7 @@ def test_continuous_embeddings(): X = torch.rand(bsz, n_cont_cols) cont_embed = ContinuousEmbeddings( - n_cont_cols=n_cont_cols, embed_dim=embed_dim, activation=None, bias=None + n_cont_cols=n_cont_cols, embed_dim=embed_dim, activation=None, use_bias=False ) out = cont_embed(X) res = ( @@ -157,18 +169,31 @@ def test_full_embed_dropout(): bsz = 1 cat = 10 esz = 4 - full_embedding_dropout = FullEmbeddingDropout(dropout=0.5) + full_embedding_dropout = FullEmbeddingDropout(dropout=0.8) inp = torch.rand(bsz, cat, esz) out = full_embedding_dropout(inp) # simply check that at least 1 full row is all 0s - assert torch.any(torch.sum(out[0] == 0, axis=1) == esz) + assert (torch.sum(out[0] == 0, axis=1) == esz).sum() > 0 # ############################################################################### -# # Beginning of a 360 test of SAINT and TabTransformer +# # Beginning of a 360 test of the Transformer family # ############################################################################### +def _build_model(model_name, params): + if model_name == "tabtransformer": + return TabTransformer(n_blocks=2, n_heads=2, **params) + if model_name == "saint": + return SAINT(n_blocks=2, n_heads=2, **params) + if model_name == "fttransformer": + return FTTransformer(n_blocks=2, n_heads=2, kv_compression_factor=0.5, **params) + if model_name == "tabfastformer": + return TabFastFormer(n_blocks=2, n_heads=2, **params) + if model_name == "tabperceiver": + return TabPerceiver(n_perceiver_blocks=2, n_latents=2, latent_dim=16, **params) + + @pytest.mark.parametrize( "embed_continuous, with_cls_token, model_name", [ @@ -176,53 +201,94 @@ def test_full_embed_dropout(): (True, False, "tabtransformer"), (False, True, "tabtransformer"), (False, False, "tabtransformer"), - (True, True, "saint"), - (True, False, "saint"), - (False, True, "saint"), - (False, False, "saint"), ], ) -def test_embed_continuous_and_with_cls_token( +def test_embed_continuous_and_with_cls_token_tabtransformer( embed_continuous, with_cls_token, model_name ): if with_cls_token: X = X_tab_with_cls_token n_colnames = ["cls_token"] + copy(colnames) cont_idx = n_cols + 1 + with_cls_token_embed_input = [("cls_token", 1)] + embed_input else: X = X_tab n_colnames = copy(colnames) cont_idx = n_cols - if model_name == "tabtransformer": - model = TabTransformer( - column_idx={k: v for v, k in enumerate(n_colnames)}, - embed_input=embed_input, - continuous_cols=n_colnames[cont_idx:], - embed_continuous=embed_continuous, - ) - elif model_name == "saint": - model = SAINT( - column_idx={k: v for v, k in enumerate(n_colnames)}, - embed_input=embed_input, - continuous_cols=n_colnames[cont_idx:], - embed_continuous=embed_continuous, - ) + params = { + "column_idx": {k: v for v, k in enumerate(n_colnames)}, + "embed_input": with_cls_token_embed_input if with_cls_token else embed_input, + "continuous_cols": n_colnames[cont_idx:], + "embed_continuous": embed_continuous, + } + + model = _build_model(model_name, params) + out = model(X) res = [out.size(0) == 10] if with_cls_token: if embed_continuous: - res.append(model._set_mlp_hidden_dims()[0] == model.input_dim) + res.append(model._compute_attn_output_dim() == model.input_dim) else: res.append( - model._set_mlp_hidden_dims()[0] == model.input_dim + len(cont_cols) + model._compute_attn_output_dim() == model.input_dim + len(cont_cols) ) elif embed_continuous: mlp_first_h = X.shape[1] * model.input_dim - res.append(model._set_mlp_hidden_dims()[0] == mlp_first_h) + res.append(model._compute_attn_output_dim() == mlp_first_h) else: mlp_first_h = len(embed_cols) * model.input_dim + 2 - res.append(model._set_mlp_hidden_dims()[0] == mlp_first_h) + res.append(model._compute_attn_output_dim() == mlp_first_h) + + assert all(res) + + +@pytest.mark.parametrize( + "with_cls_token, model_name", + [ + (True, "saint"), + (False, "saint"), + (True, "fttransformer"), + (False, "fttransformer"), + (True, "tabfastformer"), + (False, "tabfastformer"), + ], +) +def test_embed_continuous_and_with_cls_token_transformer_family( + with_cls_token, model_name +): + if with_cls_token: + X = X_tab_with_cls_token + n_colnames = ["cls_token"] + copy(colnames) + cont_idx = n_cols + 1 + with_cls_token_embed_input = [("cls_token", 1)] + embed_input + else: + X = X_tab + n_colnames = copy(colnames) + cont_idx = n_cols + + params = { + "column_idx": {k: v for v, k in enumerate(n_colnames)}, + "embed_input": with_cls_token_embed_input if with_cls_token else embed_input, + "continuous_cols": n_colnames[cont_idx:], + } + + total_n_cols = n_cols * 2 + model = _build_model(model_name, params) + + out = model(X) + res = [out.size(0) == 10] + if with_cls_token: + if model_name in ["saint", "tabfastformer"]: + res.append(out.shape[1] == model.input_dim * 2) + elif model_name == "fttransformer": + res.append(out.shape[1] == model.input_dim) + else: + if model_name in ["saint", "tabfastformer"]: + res.append(out.shape[1] == (total_n_cols * model.input_dim) * 2) + elif model_name == "fttransformer": + res.append(out.shape[1] == (total_n_cols * model.input_dim)) assert all(res) @@ -230,32 +296,39 @@ def test_embed_continuous_and_with_cls_token( @pytest.mark.parametrize( "activation, model_name", [ - ("relu", "tabtransformer"), + ("tanh", "tabtransformer"), ("leaky_relu", "tabtransformer"), - ("gelu", "tabtransformer"), ("geglu", "tabtransformer"), - ("relu", "saint"), + ("reglu", "tabtransformer"), + ("tanh", "saint"), ("leaky_relu", "saint"), - ("gelu", "saint"), ("geglu", "saint"), + ("reglu", "saint"), + ("tanh", "fttransformer"), + ("leaky_relu", "fttransformer"), + ("geglu", "fttransformer"), + ("reglu", "fttransformer"), + ("tanh", "tabfastformer"), + ("leaky_relu", "tabfastformer"), + ("geglu", "tabfastformer"), + ("reglu", "tabfastformer"), + ("tanh", "tabperceiver"), + ("leaky_relu", "tabperceiver"), + ("geglu", "tabperceiver"), + ("reglu", "tabperceiver"), ], ) def test_transformer_activations(activation, model_name): - if model_name == "tabtransformer": - model = TabTransformer( - column_idx={k: v for v, k in enumerate(colnames)}, - embed_input=embed_input, - continuous_cols=colnames[n_cols:], - transformer_activation=activation, - ) - elif model_name == "saint": - model = SAINT( - column_idx={k: v for v, k in enumerate(colnames)}, - embed_input=embed_input, - continuous_cols=colnames[n_cols:], - transformer_activation=activation, - ) + params = { + "column_idx": {k: v for v, k in enumerate(colnames)}, + "embed_input": embed_input, + "continuous_cols": colnames[n_cols:], + "transformer_activation": activation, + } + + model = _build_model(model_name, params) + out = model(X_tab) assert out.size(0) == 10 @@ -270,30 +343,33 @@ def test_transformer_activations(activation, model_name): [ "tabtransformer", "saint", + "fttransformer", + "tabfastformer", + "tabperceiver", ], ) -def test_tabtransformer_keep_attn(model_name): - if model_name == "tabtransformer": - model = TabTransformer( - column_idx={k: v for v, k in enumerate(colnames)}, - embed_input=embed_input, - continuous_cols=colnames[n_cols:], - n_blocks=4, - keep_attn_weights=True, - ) - elif model_name == "saint": - model = SAINT( - column_idx={k: v for v, k in enumerate(colnames)}, - embed_input=embed_input, - continuous_cols=colnames[n_cols:], - n_blocks=4, - keep_attn_weights=True, - ) +def test_transformers_keep_attn(model_name): + + params = { + "column_idx": {k: v for v, k in enumerate(colnames)}, + "embed_input": embed_input, + "continuous_cols": colnames[n_cols:], + } + + # n_cols is an unfortunate name I might change in the future. It refers to + # the number of cat and cont cols, so the total number of cols is + # n_cols * 2 + total_n_cols = n_cols * 2 + + model = _build_model(model_name, params) + out = model(X_tab) res = [out.size(0) == 10] - res.append(out.size(1) == model._set_mlp_hidden_dims()[-1]) - res.append(len(model.attention_weights) == model.n_blocks) + if model_name != "tabperceiver": + res.append(len(model.attention_weights) == model.n_blocks) + else: + res.append(len(model.attention_weights) == model.n_perceiver_blocks) if model_name == "tabtransformer": res.append( @@ -303,10 +379,73 @@ def test_tabtransformer_keep_attn(model_name): elif model_name == "saint": res.append( list(model.attention_weights[0][0].shape) - == [10, model.n_heads, n_cols, n_cols] + == [10, model.n_heads, total_n_cols, total_n_cols] + ) + res.append( + list(model.attention_weights[0][1].shape) + == [1, model.n_heads, X_tab.shape[0], X_tab.shape[0]] + ) + if model_name == "fttransformer": + res.append( + list(model.attention_weights[0].shape) + == [ + 10, + model.n_heads, + total_n_cols, + int(model.n_feats * model.kv_compression_factor), + ] + ) + elif model_name == "tabperceiver": + res.append( + len(model.attention_weights[0]) + == model.n_cross_attns + model.n_latent_blocks + ) + res.append( + list(model.attention_weights[0][0].shape) + == [10, model.n_cross_attn_heads, model.n_latents, X_tab.shape[1]] ) res.append( list(model.attention_weights[0][1].shape) - == [1, model.n_heads, n_cols * n_embed, n_cols * n_embed] + == [10, model.n_cross_attn_heads, model.n_latents, model.n_latents] + ) + elif model_name == "tabfastformer": + res.append( + list(model.attention_weights[0][0].shape) + == [10, model.n_heads, total_n_cols] + ) + res.append( + list(model.attention_weights[0][1].shape) + == [10, model.n_heads, total_n_cols] ) assert all(res) + + +############################################################################### +# Test FTTransformer mlp set up +############################################################################### + + +@pytest.mark.parametrize( + "mlp_first_h, shoud_work", + [ + ((n_cols * 2) * 64, True), + ((n_cols * 2) * (64 + 1), False), + ], +) +def test_ft_transformer_mlp(mlp_first_h, shoud_work): + + mlp_hidden_dims = [mlp_first_h, mlp_first_h * 2] + + params = { + "column_idx": {k: v for v, k in enumerate(colnames)}, + "embed_input": embed_input, + "continuous_cols": colnames[n_cols:], + "mlp_hidden_dims": mlp_hidden_dims, + } + + if shoud_work: + model = _build_model("fttransformer", params) + assert True + else: + with pytest.raises(AssertionError): + model = _build_model("fttransformer", params) # noqa: F841 diff --git a/tests/test_model_functioning/test_miscellaneous.py b/tests/test_model_functioning/test_miscellaneous.py index f8ea0b8a1b7d28298484fa3cedf302fcdb61636c..3837db1ccb4c87e378b663fcec45518b5bd1a960 100644 --- a/tests/test_model_functioning/test_miscellaneous.py +++ b/tests/test_model_functioning/test_miscellaneous.py @@ -5,6 +5,7 @@ from copy import deepcopy import numpy as np import torch +import pandas as pd import pytest from sklearn.model_selection import train_test_split @@ -21,6 +22,7 @@ from pytorch_widedeep.models import ( from pytorch_widedeep.metrics import Accuracy, Precision from pytorch_widedeep.training import Trainer from pytorch_widedeep.callbacks import EarlyStopping +from pytorch_widedeep.preprocessing import TabPreprocessor # Wide array X_wide = np.random.choice(50, (32, 10)) @@ -342,3 +344,61 @@ def test_save_load_and_predict(): shutil.rmtree(fpath) assert preds.shape[0] == X_tab.shape[0] + + +############################################################################### +# test get_embeddings DeprecationWarning +############################################################################### + + +def create_test_dataset(input_type, input_type_2=None): + df = pd.DataFrame() + col1 = list(np.random.choice(input_type, 32)) + if input_type_2 is not None: + col2 = list(np.random.choice(input_type_2, 32)) + else: + col2 = list(np.random.choice(input_type, 32)) + df["col1"], df["col2"] = col1, col2 + return df + + +some_letters = ["a", "b", "c", "d", "e"] + +df = create_test_dataset(some_letters) +df["col3"] = np.round(np.random.rand(32), 3) +df["col4"] = np.round(np.random.rand(32), 3) +df["target"] = np.random.choice(2, 32) + + +def test_get_embeddings_deprecation_warning(): + + embed_cols = [("col1", 5), ("col2", 5)] + continuous_cols = ["col3", "col4"] + + tab_preprocessor = TabPreprocessor( + embed_cols=embed_cols, continuous_cols=continuous_cols + ) + X_tab = tab_preprocessor.fit_transform(df) + target = df.target.values + + tabmlp = TabMlp( + mlp_hidden_dims=[32, 16], + mlp_dropout=[0.5, 0.5], + column_idx={k: v for v, k in enumerate(df.columns)}, + embed_input=tab_preprocessor.embeddings_input, + continuous_cols=tab_preprocessor.continuous_cols, + ) + + model = WideDeep(deeptabular=tabmlp) + trainer = Trainer(model, objective="binary", verbose=0) + trainer.fit( + X_tab=X_tab, + target=target, + batch_size=16, + ) + + with pytest.warns(DeprecationWarning): + trainer.get_embeddings( + col_name="col1", + cat_encoding_dict=tab_preprocessor.label_encoder.encoding_dict, + ) diff --git a/tests/test_tab2vec/test_t2v.py b/tests/test_tab2vec/test_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..c2468fa8d9ef0082369f9da0cd4ec36d52b8695e --- /dev/null +++ b/tests/test_tab2vec/test_t2v.py @@ -0,0 +1,205 @@ +import string +from random import choices + +import numpy as np +import pandas as pd +import pytest + +from pytorch_widedeep import Tab2Vec +from pytorch_widedeep.models import ( + SAINT, + TabMlp, + TabNet, + WideDeep, + TabResnet, + TabPerceiver, + FTTransformer, + TabFastFormer, + TabTransformer, +) +from pytorch_widedeep.preprocessing import TabPreprocessor + +colnames = list(string.ascii_lowercase)[:4] + ["target"] +cat_col1_vals = ["a", "b", "c"] +cat_col2_vals = ["d", "e", "f"] + + +def create_df(): + cat_cols = [np.array(choices(c, k=5)) for c in [cat_col1_vals, cat_col2_vals]] + cont_cols = [np.round(np.random.rand(5), 2) for _ in range(2)] + target = [np.random.choice(2, 5)] + return pd.DataFrame( + np.vstack(cat_cols + cont_cols + target).transpose(), columns=colnames + ) + + +df_init = create_df() +df_t2v = create_df() + +embed_cols = [("a", 2), ("b", 4)] +cont_cols = ["c", "d"] +tab_preprocessor = TabPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols) +X_tab = tab_preprocessor.fit_transform(df_init) + +tabmlp = TabMlp( + column_idx=tab_preprocessor.column_idx, + embed_input=tab_preprocessor.embeddings_input, + continuous_cols=tab_preprocessor.continuous_cols, + mlp_hidden_dims=[8, 4], +) + +tabresnet = TabResnet( + column_idx=tab_preprocessor.column_idx, + embed_input=tab_preprocessor.embeddings_input, + continuous_cols=tab_preprocessor.continuous_cols, + blocks_dims=[8, 8, 4], +) + +tabnet = TabNet( + column_idx=tab_preprocessor.column_idx, + embed_input=tab_preprocessor.embeddings_input, + continuous_cols=tab_preprocessor.continuous_cols, +) + + +@pytest.mark.parametrize( + "deeptabular", + [tabmlp, tabresnet, tabnet], +) +def test_non_transformer_models(deeptabular): + + model = WideDeep(deeptabular=deeptabular) + + # Let's assume the model is trained + t2v = Tab2Vec(model, tab_preprocessor) + X_vec, _ = t2v.fit_transform(df_t2v, target_col="target") + + embed_dim = sum([el[2] for el in tab_preprocessor.embeddings_input]) + cont_dim = len(tab_preprocessor.continuous_cols) + assert X_vec.shape[1] == embed_dim + cont_dim + + +############################################################################### +# Test Transformer models +############################################################################### + + +def _build_model(model_name, params): + if model_name == "tabtransformer": + return TabTransformer(input_dim=8, n_heads=2, n_blocks=2, **params) + if model_name == "saint": + return SAINT(input_dim=8, n_heads=2, n_blocks=2, **params) + if model_name == "fttransformer": + return FTTransformer(n_blocks=2, n_heads=2, kv_compression_factor=0.5, **params) + if model_name == "tabfastformer": + return TabFastFormer(n_blocks=2, n_heads=2, **params) + if model_name == "tabperceiver": + return TabPerceiver( + input_dim=8, + n_cross_attn_heads=2, + n_latents=2, + latent_dim=8, + n_latent_heads=2, + n_perceiver_blocks=2, + share_weights=False, + **params + ) + + +@pytest.mark.parametrize( + "model_name, with_cls_token, share_embeddings, embed_continuous", + [ + ("tabtransformer", False, False, False), + ("tabtransformer", True, False, False), + ("tabtransformer", False, True, False), + ("tabtransformer", True, False, True), + ], +) +def test_tab_transformer_models( + model_name, with_cls_token, share_embeddings, embed_continuous +): + + embed_cols = ["a", "b"] + cont_cols = ["c", "d"] + + tab_preprocessor = TabPreprocessor( + embed_cols=embed_cols, + continuous_cols=cont_cols, + for_transformer=True, + with_cls_token=with_cls_token, + shared_embed=share_embeddings, + ) + X_tab = tab_preprocessor.fit_transform(df_init) # noqa: F841 + + params = { + "column_idx": tab_preprocessor.column_idx, + "embed_input": tab_preprocessor.embeddings_input, + "continuous_cols": tab_preprocessor.continuous_cols, + "embed_continuous": embed_continuous, + } + + deeptabular = _build_model(model_name, params) + + # Let's assume the model is trained + model = WideDeep(deeptabular=deeptabular) + t2v = Tab2Vec(model, tab_preprocessor) + X_vec = t2v.transform(df_t2v) + + if embed_continuous: + out_dim = (len(embed_cols) + len(cont_cols)) * deeptabular.input_dim + else: + out_dim = len(embed_cols) * deeptabular.input_dim + len(cont_cols) + + assert X_vec.shape[1] == out_dim + + +@pytest.mark.parametrize( + "model_name, with_cls_token, share_embeddings", + [ + ("saint", False, True), + ("saint", True, True), + ("saint", False, False), + ("fttransformer", False, True), + ("fttransformer", True, True), + ("fttransformer", False, False), + ("tabfastformer", False, True), + ("tabfastformer", True, True), + ("tabfastformer", False, False), + ( + "tabperceiver", + False, + True, + ), # for the perceiver we do not need with_cls_token + ("tabperceiver", False, False), + ], +) +def test_transformer_family_models(model_name, with_cls_token, share_embeddings): + + embed_cols = ["a", "b"] + cont_cols = ["c", "d"] + + tab_preprocessor = TabPreprocessor( + embed_cols=embed_cols, + continuous_cols=cont_cols, + for_transformer=True, + with_cls_token=with_cls_token, + shared_embed=share_embeddings, + ) + X_tab = tab_preprocessor.fit_transform(df_init) # noqa: F841 + + params = { + "column_idx": tab_preprocessor.column_idx, + "embed_input": tab_preprocessor.embeddings_input, + "continuous_cols": tab_preprocessor.continuous_cols, + } + + deeptabular = _build_model(model_name, params) + + # Let's assume the model is trained + model = WideDeep(deeptabular=deeptabular) + t2v = Tab2Vec(model, tab_preprocessor) + X_vec = t2v.transform(df_t2v) + + out_dim = (len(embed_cols) + len(cont_cols)) * deeptabular.input_dim + + assert X_vec.shape[1] == out_dim