Unverified commit 0c79deb5 authored by J Javier, committed by GitHub

Merge pull request #47 from jrzaurin/saint

Saint
...@@ -22,6 +22,8 @@ using wide and deep models.
**Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)
**Slack**: if you want to contribute or just want to chat with us, join [slack](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
### Introduction
``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792)
...@@ -82,10 +84,11 @@ into:
It is important to emphasize that **each individual component, `wide`,
`deeptabular`, `deeptext` and `deepimage`, can be used independently** and in
isolation. For example, one could use only `wide`, which is simply a
linear model. In fact, one of the most interesting functionalities
in ``pytorch-widedeep`` is the ``deeptabular`` component. Currently,
``pytorch-widedeep`` offers the following different models for that
component:
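As a toy illustration of the point above (a hedged sketch, not the library's actual `Wide` class), the `wide` component on its own amounts to a plain linear model over one-hot encoded (and optionally crossed) categorical features:

```python
import numpy as np

# Illustrative sketch only: a "wide" model is a linear map from one-hot
# encoded features to a single score per sample. All names are made up here.
rng = np.random.default_rng(0)
n_samples, n_features = 4, 6
X_onehot = (rng.random((n_samples, n_features)) > 0.5).astype(float)
weights = rng.normal(size=(n_features, 1))
bias = np.zeros(1)
logits = X_onehot @ weights + bias  # one score per sample, shape (4, 1)
```

In the library itself this linear layer is learned jointly with the deep components rather than in isolation.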
1. ``TabMlp``: this is almost identical to the [tabular
model](https://docs.fast.ai/tutorial.tabular.html) in the fantastic
...@@ -100,11 +103,26 @@ passed through a series of ResNet blocks built with dense layers.
[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)
4. ``TabTransformer``: Details on the TabTransformer can be found in:
[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
Note that the TabTransformer implementation available at ``pytorch-widedeep``
is an adaptation of the original implementation.
5. ``FT-Transformer``: or Feature Tokenizer transformer. This is a relatively small
variation of the ``TabTransformer``. The variation itself was first
introduced in the ``SAINT`` paper, but the name "``FT-Transformer``" was first
used in
[Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959).
When using the ``FT-Transformer`` each continuous feature is "embedded"
(i.e. passed through a 1-layer MLP with or without an activation function) and
then passed through the attention blocks along with the categorical features.
This is available in ``pytorch-widedeep``'s ``TabTransformer`` by setting the
parameter ``embed_continuous = True``.
6. ``SAINT``: Details on SAINT can be found in:
[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
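The continuous-feature "embedding" mentioned in point 5 can be sketched in a few lines, assuming the simplest 1-layer-MLP form: each continuous value gets its own weight and bias vectors and is mapped to a d-dimensional token that joins the categorical tokens in the attention blocks. The names below are illustrative, not the library's internals:

```python
import numpy as np

# Sketch (assumptions, not library code): "embed" each continuous feature
# with its own 1-layer MLP, i.e. token_i = activation(x_i * W_i + b_i).
rng = np.random.default_rng(42)
n_cont, embed_dim = 3, 8                  # 3 continuous features, 8-dim tokens
W = rng.normal(size=(n_cont, embed_dim))  # one weight vector per feature
b = rng.normal(size=(n_cont, embed_dim))  # one bias vector per feature
x = np.array([0.5, -1.2, 3.0])            # one row of continuous values
tokens = np.maximum(x[:, None] * W + b, 0.0)  # optional ReLU activation
# tokens now has shape (3, 8): one attention-ready token per continuous feature
```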
For details on these models and their options please see the examples in the
Examples folder and the documentation.
Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
...@@ -143,20 +161,19 @@ cd pytorch-widedeep
pip install -e .
```
**Important note for Mac users**: at the time of writing the latest `torch`
release is `1.9`. Some past [issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
when running on Mac, present in previous versions, persist on this release
and the data-loaders will not run in parallel. In addition, since `python
3.8`, [the `multiprocessing` library start method changed from `'fork'` to `'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will
not run in parallel. Therefore, for Mac users I recommend using `python 3.6`
or `3.7` and `torch <= 1.6` (with the corresponding, consistent version of
`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
versioning in the `setup.py` file since I expect that all these issues will be
fixed in the future. Therefore, after installing `pytorch-widedeep` via pip
or directly from GitHub, downgrade `torch` and `torchvision` manually:
```bash
pip install pytorch-widedeep
......
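The `'fork'`-vs-`'spawn'` issue described above can also be detected at runtime. This is a small illustrative helper (an assumption of mine, not part of the library) that falls back to single-process data loading, i.e. a `num_workers` of 0, on affected setups:

```python
import platform
import sys

def safe_num_workers(requested: int) -> int:
    """Illustrative helper: pick a DataLoader-style num_workers value.

    On macOS with Python >= 3.8 the default multiprocessing start method
    is 'spawn', which breaks parallel data loading as described above, so
    fall back to 0 (no worker processes) there.
    """
    if platform.system() == "Darwin" and sys.version_info >= (3, 8):
        return 0
    return requested
```

On Linux, where `'fork'` remains the default start method, the requested value is returned unchanged.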
1.0.5
...@@ -15,3 +15,4 @@ them to address different problems
* `Custom Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/07_Custom_Components.ipynb>`__
* `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__
* `Using Custom DataLoaders and Torchmetrics <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb>`__
* `The Transformer Family <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/10_The_Transformer_Family.ipynb>`__
docs/figures/tabmlp_arch.png (updated: 46.0 KB → 47.7 KB)
docs/figures/tabresnet_arch.png (updated: 78.6 KB → 81.1 KB)
docs/figures/tabtransformer_arch.png (updated: 63.1 KB → 63.7 KB)
...@@ -5,9 +5,9 @@ This module contains the four main components that will comprise a Wide and
Deep model, and the ``WideDeep`` "constructor" class. These four components
are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
.. note:: ``TabMlp``, ``TabResnet``, ``TabNet``, ``TabTransformer`` and ``SAINT`` can
   all be used as the ``deeptabular`` component of the model and simply
   represent different alternatives
.. autoclass:: pytorch_widedeep.models.wide.Wide
   :exclude-members: forward
...@@ -25,7 +25,11 @@ are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
   :exclude-members: forward
   :members:
.. autoclass:: pytorch_widedeep.models.transformers.tab_transformer.TabTransformer
   :exclude-members: forward
   :members:

.. autoclass:: pytorch_widedeep.models.transformers.saint.SAINT
   :exclude-members: forward
   :members:
......
...@@ -181,7 +181,7 @@
{
"data": {
"text/plain": [
"tensor([-0.2856], grad_fn=<AddBackward0>)"
]
},
"execution_count": 7,
...@@ -201,7 +201,7 @@
{
"data": {
"text/plain": [
"tensor([-0.2856], grad_fn=<AddBackward0>)"
]
},
"execution_count": 8,
...@@ -222,18 +222,9 @@
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from pytorch_widedeep.models import Wide"
]
...@@ -313,7 +304,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
...@@ -322,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
...@@ -340,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
...@@ -352,7 +343,7 @@
")"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
...@@ -370,7 +361,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
...@@ -379,7 +370,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
...@@ -399,7 +390,7 @@
")"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
...@@ -426,7 +417,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
...@@ -444,7 +435,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
...@@ -453,7 +444,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
...@@ -462,7 +453,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
...@@ -569,7 +560,7 @@
")"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
...@@ -580,18 +571,19 @@
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-1.0338e-03, 5.0809e-01, -5.1775e-04, 2.8709e-01, -5.5744e-03,\n",
"         -5.9626e-03, 1.2294e-01, 1.6768e-01],\n",
"        [-1.1770e-03, 3.8934e-02, -2.4541e-03, 6.6003e-03, -4.3299e-03,\n",
"         -5.0524e-03, 4.7879e-03, -3.5898e-04]], grad_fn=<LeakyReluBackward1>)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
...@@ -609,7 +601,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
...@@ -618,7 +610,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
...@@ -666,7 +658,7 @@
")"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
...@@ -690,7 +682,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
......
...@@ -519,6 +519,7 @@
" (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n",
" )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (tab_mlp): MLP(\n",
" (mlp): Sequential(\n",
" (dense_layer_0): Sequential(\n",
...@@ -588,16 +589,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:07<00:00, 83.83it/s, loss=0.425, metrics={'acc': 0.801, 'prec': 0.6074}] \n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 129.25it/s, loss=0.362, metrics={'acc': 0.8341, 'prec': 0.6947}]\n",
"epoch 2: 100%|██████████| 611/611 [00:07<00:00, 79.96it/s, loss=0.373, metrics={'acc': 0.8245, 'prec': 0.6621}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 140.03it/s, loss=0.356, metrics={'acc': 0.8353, 'prec': 0.6742}]\n",
"epoch 3: 100%|██████████| 611/611 [00:07<00:00, 79.08it/s, loss=0.364, metrics={'acc': 0.8288, 'prec': 0.6729}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 150.18it/s, loss=0.35, metrics={'acc': 0.838, 'prec': 0.6875}] \n",
"epoch 4: 100%|██████████| 611/611 [00:07<00:00, 82.86it/s, loss=0.358, metrics={'acc': 0.8319, 'prec': 0.6814}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 147.48it/s, loss=0.345, metrics={'acc': 0.8394, 'prec': 0.6949}]\n",
"epoch 5: 100%|██████████| 611/611 [00:07<00:00, 78.20it/s, loss=0.354, metrics={'acc': 0.8337, 'prec': 0.6872}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 150.62it/s, loss=0.344, metrics={'acc': 0.8426, 'prec': 0.7066}]\n"
]
}
],
...@@ -646,16 +647,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:08<00:00, 67.98it/s, loss=0.385, metrics={'acc': 0.8182, 'prec': 0.6465}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 146.93it/s, loss=0.359, metrics={'acc': 0.8361, 'prec': 0.6862}]\n",
"epoch 2: 100%|██████████| 611/611 [00:09<00:00, 67.33it/s, loss=0.363, metrics={'acc': 0.8296, 'prec': 0.6756}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 148.88it/s, loss=0.354, metrics={'acc': 0.8353, 'prec': 0.7058}]\n",
"epoch 3: 100%|██████████| 611/611 [00:09<00:00, 67.31it/s, loss=0.357, metrics={'acc': 0.8312, 'prec': 0.6822}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 142.79it/s, loss=0.351, metrics={'acc': 0.8387, 'prec': 0.6813}]\n",
"epoch 4: 100%|██████████| 611/611 [00:09<00:00, 62.17it/s, loss=0.353, metrics={'acc': 0.8347, 'prec': 0.6897}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 124.27it/s, loss=0.348, metrics={'acc': 0.8404, 'prec': 0.692}] \n",
"epoch 5: 100%|██████████| 611/611 [00:17<00:00, 35.55it/s, loss=0.35, metrics={'acc': 0.8347, 'prec': 0.6893}] \n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 116.47it/s, loss=0.345, metrics={'acc': 0.8427, 'prec': 0.6936}]\n"
]
}
],
...@@ -667,7 +668,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the `FT-Transformer` as the `deeptabular` component"
]
},
{
...@@ -702,42 +703,36 @@
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# deeptabular\n",
"tab_preprocessor = TabPreprocessor(embed_cols=cat_embed_cols_for_transformer, \n",
"                                   continuous_cols=continuous_cols, \n",
"                                   for_transformer=True, \n",
"                                   with_cls_token = True) # you need to define this since it changes pre-processing\n",
"X_tab = tab_preprocessor.fit_transform(df)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
"deeptabular = TabTransformer(column_idx=tab_preprocessor.column_idx,\n",
"                             embed_input=tab_preprocessor.embeddings_input,\n",
"                             continuous_cols=continuous_cols, \n",
"                             embed_continuous = True, \n",
"                             embed_continuous_activation = \"relu\")\n",
"model = WideDeep(wide=wide, deeptabular=deeptabular)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
...@@ -753,10 +748,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 77/77 [00:21<00:00, 3.57it/s, loss=0.446, metrics={'acc': 0.7905, 'prec': 0.5849}]\n",
"valid: 100%|██████████| 20/20 [00:01<00:00, 13.74it/s, loss=0.374, metrics={'acc': 0.8227, 'prec': 0.6443}]\n",
"epoch 2: 100%|██████████| 77/77 [00:21<00:00, 3.63it/s, loss=0.377, metrics={'acc': 0.8231, 'prec': 0.6586}]\n",
"valid: 100%|██████████| 20/20 [00:01<00:00, 14.04it/s, loss=0.372, metrics={'acc': 0.8216, 'prec': 0.6112}]\n"
]
}
],
...@@ -798,16 +793,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:03<00:00, 159.50it/s, loss=0.46, metrics={'acc': 0.7836, 'prec': 0.573}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 168.08it/s, loss=0.422, metrics={'acc': 0.805, 'prec': 0.6403}] \n",
"epoch 2: 100%|██████████| 611/611 [00:03<00:00, 158.35it/s, loss=0.405, metrics={'acc': 0.8131, 'prec': 0.6643}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 214.83it/s, loss=0.394, metrics={'acc': 0.8168, 'prec': 0.6741}]\n",
"epoch 3: 100%|██████████| 611/611 [00:03<00:00, 162.96it/s, loss=0.385, metrics={'acc': 0.8201, 'prec': 0.6837}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 210.59it/s, loss=0.381, metrics={'acc': 0.8228, 'prec': 0.6799}]\n",
"epoch 4: 100%|██████████| 611/611 [00:03<00:00, 171.06it/s, loss=0.375, metrics={'acc': 0.8256, 'prec': 0.6895}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 210.97it/s, loss=0.374, metrics={'acc': 0.8259, 'prec': 0.6798}]\n",
"epoch 5: 100%|██████████| 611/611 [00:03<00:00, 157.49it/s, loss=0.369, metrics={'acc': 0.828, 'prec': 0.692}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 197.19it/s, loss=0.37, metrics={'acc': 0.8275, 'prec': 0.6856}] \n"
]
}
],
......
...@@ -541,6 +541,7 @@
" (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n",
" )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (tab_mlp): MLP(\n",
" (mlp): Sequential(\n",
" (dense_layer_0): Sequential(\n",
...@@ -657,26 +658,26 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:03<00:00, 40.76it/s, loss=0.605, metrics={'acc': 0.7653, 'rec': 0.5005}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 70.56it/s, loss=0.37, metrics={'acc': 0.8295, 'rec': 0.5646}] \n",
"epoch 2: 100%|██████████| 153/153 [00:03<00:00, 42.82it/s, loss=0.37, metrics={'acc': 0.8298, 'rec': 0.5627}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 116.22it/s, loss=0.355, metrics={'acc': 0.8372, 'rec': 0.6206}]\n",
"epoch 3: 100%|██████████| 153/153 [00:03<00:00, 41.82it/s, loss=0.354, metrics={'acc': 0.8338, 'rec': 0.5612}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 116.42it/s, loss=0.35, metrics={'acc': 0.8395, 'rec': 0.5804}] \n",
"epoch 4: 100%|██████████| 153/153 [00:03<00:00, 42.66it/s, loss=0.345, metrics={'acc': 0.8382, 'rec': 0.5658}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 115.17it/s, loss=0.35, metrics={'acc': 0.8379, 'rec': 0.6048}] \n",
"epoch 5: 100%|██████████| 153/153 [00:03<00:00, 42.11it/s, loss=0.343, metrics={'acc': 0.8391, 'rec': 0.5681}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 115.60it/s, loss=0.347, metrics={'acc': 0.84, 'rec': 0.595}] \n",
"epoch 6: 100%|██████████| 153/153 [00:03<00:00, 41.32it/s, loss=0.341, metrics={'acc': 0.8398, 'rec': 0.5748}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 109.95it/s, loss=0.347, metrics={'acc': 0.8404, 'rec': 0.5855}]\n",
"epoch 7: 100%|██████████| 153/153 [00:03<00:00, 41.79it/s, loss=0.34, metrics={'acc': 0.8413, 'rec': 0.5746}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 108.11it/s, loss=0.347, metrics={'acc': 0.8395, 'rec': 0.5898}]\n",
"epoch 8: 100%|██████████| 153/153 [00:03<00:00, 41.09it/s, loss=0.341, metrics={'acc': 0.8395, 'rec': 0.5744}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 99.26it/s, loss=0.347, metrics={'acc': 0.8404, 'rec': 0.5877}]\n",
"epoch 9: 100%|██████████| 153/153 [00:03<00:00, 41.33it/s, loss=0.34, metrics={'acc': 0.8409, 'rec': 0.573}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 108.59it/s, loss=0.347, metrics={'acc': 0.8399, 'rec': 0.5778}]\n",
"epoch 10: 100%|██████████| 153/153 [00:03<00:00, 40.06it/s, loss=0.34, metrics={'acc': 0.8413, 'rec': 0.5718}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 104.13it/s, loss=0.347, metrics={'acc': 0.8395, 'rec': 0.577}] \n"
]
},
{
...@@ -707,7 +708,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"{'train_loss': [0.6669373385656893, 0.38398634666710896, 0.358561675727757, 0.3492486531438391, 0.3467087533349305, 0.34565913872001996, 0.3435389565096961, 0.3430960536782258, 0.34279162490289977, 0.34314920641238394], 'train_acc': [0.7471143756558237, 0.8241496685690887, 0.833849461264812, 0.8375604637473447, 0.837739615591329, 0.837739615591329, 0.8401709620454022, 0.8393007959460497, 0.8405036726128017, 0.8407596038184936], 'train_rec': [0.4645416736602783, 0.5599529147148132, 0.5657289624214172, 0.5608086585998535, 0.5624130964279175, 0.5655150413513184, 0.5730024576187134, 0.5695796608924866, 0.5746068954467773, 0.5744999647140503], 'val_loss': [0.3841508023249797, 0.36289690396724605, 0.35651543507209194, 0.3551729489595462, 0.3552632217223828, 0.35234999580261034, 0.3519755387917543, 0.3516103594731062, 0.3511298107795226, 0.35126262444716233], 'val_acc': [0.8328385709898659, 0.8353976865595251, 0.8365236974101751, 0.8374449790152523, 0.8383662606203296, 0.8403111884532706, 0.8393899068481933, 0.8389804483570478, 0.8389804483570478, 0.8386733544886887], 'val_rec': [0.5701454281806946, 0.583832323551178, 0.5718562602996826, 0.5949529409408569, 0.6090675592422485, 0.5889649391174316, 0.5795551538467407, 0.5808383226394653, 0.585970938205719, 0.5761334300041199]}\n" "{'train_loss': [0.6051353691450132, 0.3695722280764112, 0.35393014296986697, 0.3445955140917909, 0.34339318926038304, 0.3414274424898858, 0.33967164684744444, 0.3409811920589871, 0.3399035648193235, 0.3402750544688281], 'train_acc': [0.7653366775010877, 0.8297545619737414, 0.8337726819031045, 0.8382002917615745, 0.8391216441020654, 0.8397614721162951, 0.8412970593504466, 0.8395055409106033, 0.8408619763007703, 0.8412970593504466], 'train_rec': [0.5004813075065613, 0.5627340078353882, 0.5612365007400513, 0.5658358931541443, 0.5680821537971497, 0.5748208165168762, 0.5746068954467773, 0.5743929743766785, 0.5730024576187134, 0.5718258619308472], 'val_loss': [0.37032382075603193, 0.35480272387846923, 
0.349816164909265, 0.3500071618801508, 0.3474775743790162, 0.3471213915409186, 0.34703178054247147, 0.3469086617995531, 0.3469338050255409, 0.3465542976672833], 'val_acc': [0.8294605384379159, 0.8372402497696796, 0.8394922714709796, 0.8378544375063978, 0.8400040945849114, 0.8404135530760569, 0.8394922714709796, 0.8404135530760569, 0.8399017299621251, 0.8394922714709796], 'val_rec': [0.5645850896835327, 0.6206158995628357, 0.5804105997085571, 0.6047903895378113, 0.5949529409408569, 0.5855432152748108, 0.589820384979248, 0.587681770324707, 0.5778443217277527, 0.5769888758659363]}\n"
] ]
} }
], ],
...@@ -749,70 +750,70 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'11th': array([-0.33080965, -0.32130778, 0.19615164, 0.13556953, 0.14851215,\n", "{'11th': array([ 0.17823647, 0.04097574, -0.16298912, -0.11065536, 0.1432162 ,\n",
" 0.13719487, -0.16899936, 0.3414919 , -0.19812936, -0.04245757,\n", " 0.10531982, -0.34251764, 0.40085673, 0.09578304, 0.15393786,\n",
" 0.351975 , 0.20050879, 0.31216618, -0.5121311 , 0.43086612,\n", " 0.26928946, -0.05603978, -0.2609236 , 0.0091235 , 0.07494199,\n",
" 0.53559804], dtype=float32),\n", " 0.02190116], dtype=float32),\n",
" 'HS-grad': array([-0.16151328, 0.23065983, -0.00747482, 0.12380929, 0.12916717,\n", " 'HS-grad': array([ 0.120097 , -0.13213032, -0.0592633 , 0.04583196, -0.04858546,\n",
" -0.06022112, -0.10307326, 0.00917153, 0.08243465, -0.3448795 ,\n", " -0.39242733, -0.43368143, 0.00434827, 0.04477202, 0.07125217,\n",
" 0.11542231, 0.37100902, -0.06485075, 0.2418356 , 0.05066948,\n", " -0.15088314, -0.2939101 , 0.31975606, -0.341947 , 0.22773097,\n",
" -0.22904117], dtype=float32),\n", " 0.28342503], dtype=float32),\n",
" 'Assoc-acdm': array([ 0.15888667, 0.05965156, -0.07115431, -0.15109068, -0.10918076,\n", " 'Assoc-acdm': array([ 0.14957719, -0.18953936, -0.22840326, 0.45375347, 0.26669678,\n",
" -0.05370647, -0.21373998, 0.30832142, -0.29594633, 0.20489004,\n", " 0.05090672, 0.46574584, 0.2774832 , -0.12203862, 0.13699052,\n",
" 0.18331873, 0.18607427, -0.19573471, -0.10851611, 0.30840558,\n", " -0.27128282, -0.34413835, 0.29697102, 0.12395442, 0.14231798,\n",
" -0.22811931], dtype=float32),\n", " -0.10790487], dtype=float32),\n",
" 'Some-college': array([-0.28310186, 0.2633668 , 0.21454443, -0.53031933, -0.28034252,\n", " 'Some-college': array([-0.2126067 , 0.04664122, -0.15191978, -0.10957965, -0.12881616,\n",
" -0.2035272 , -0.09638138, -0.09200009, -0.04980965, -0.3110542 ,\n", " -0.04466751, 0.25502843, 0.32889867, 0.0168101 , 0.20086999,\n",
" -0.0603297 , 0.07123059, -0.13342348, 0.08563767, -0.45675036,\n", " -0.21912436, -0.00544369, 0.03351 , -0.17859232, 0.1382413 ,\n",
" 0.11884805], dtype=float32),\n", " 0.26502082], dtype=float32),\n",
" '10th': array([-0.12832558, -0.6697091 , -0.14748466, -0.0828823 , 0.42625666,\n", " '10th': array([-0.3121446 , 0.19805874, -0.03366002, 0.1288065 , 0.26396075,\n",
" -0.0548928 , -0.03520742, -0.29257616, 0.16213404, -0.10930534,\n", " -0.05587888, 0.22792356, -0.06681106, 0.12476017, 0.37026265,\n",
" 0.04222836, -0.02641041, 0.04289234, -0.21711953, 0.29988793,\n", " 0.03204104, -0.09612755, 0.0324997 , -0.08246089, 0.04117873,\n",
" -0.03914425], dtype=float32),\n", " 0.1853117 ], dtype=float32),\n",
" 'Prof-school': array([-0.37946448, -0.33668968, 0.4209301 , 0.01543873, 0.17252892,\n", " 'Prof-school': array([-0.4429325 , -0.12834997, 0.3658504 , 0.48140833, 0.11574885,\n",
" 0.30649292, 0.02050802, 0.11289834, -0.02656657, -0.24030082,\n", " -0.192547 , 0.1586941 , -0.2919336 , 0.1567621 , 0.29656097,\n",
" -0.32401893, 0.1216528 , 0.01666113, 0.09014473, -0.06262966,\n", " 0.18974394, 0.06253866, 0.16234514, -0.08963383, -0.08024175,\n",
" -0.04359816], dtype=float32),\n", " 0.54286146], dtype=float32),\n",
" '7th-8th': array([-0.16373637, -0.48016912, 0.7061354 , -0.02272724, -0.3922257 ,\n", " '7th-8th': array([ 0.54942334, 0.37394103, -0.03598195, -0.05772773, -0.28254417,\n",
" 0.09848034, -0.07539301, -0.03332449, 0.36090446, -0.29780784,\n", " 0.54470855, -0.6513119 , -0.13811558, -0.11478714, 0.06010893,\n",
" 0.2516029 , 0.02890976, 0.26914543, 0.22876695, -0.29285768,\n", " -0.2462508 , 0.1755247 , 0.10117105, 0.36358032, -0.09656113,\n",
" -0.3230711 ], dtype=float32),\n", " 0.34954002], dtype=float32),\n",
" 'Bachelors': array([ 0.02716259, 0.01603995, 0.11296882, 0.09937789, -0.15349191,\n", " 'Bachelors': array([ 0.06564163, -0.23048915, -0.12470629, -0.02602417, 0.35001647,\n",
" 0.36589023, -0.22568569, -0.06556027, 0.03241782, -0.03301224,\n", " -0.18802756, 0.10905975, -0.33273023, 0.01738172, 0.2478116 ,\n",
" -0.15650426, 0.33991587, -0.08569644, -0.0560803 , -0.22597635,\n", " 0.00981276, -0.18224423, 0.0950555 , 0.17849174, 0.17942917,\n",
" -0.08900049], dtype=float32),\n", " 0.31124604], dtype=float32),\n",
" 'Masters': array([ 0.20437175, -0.10703096, -0.06892612, 0.34738615, -0.11776404,\n", " 'Masters': array([ 0.13041618, -0.07283561, -0.34077218, 0.05142086, 0.08315329,\n",
" -0.38715032, 0.10983769, 0.39286137, -0.43856898, 0.23008673,\n", " -0.12212724, 0.31239262, -0.20927685, -0.24285726, 0.06567737,\n",
" 0.08582868, 0.11090949, -0.2212543 , -0.1813675 , -0.34878278,\n", " 0.03671836, -0.03405587, 0.01641322, 0.17043172, -0.38756114,\n",
" -0.06638747], dtype=float32),\n", " 0.30868122], dtype=float32),\n",
" 'Doctorate': array([-0.36630204, -0.03157847, -0.21379165, -0.3201302 , -0.18573532,\n", " 'Doctorate': array([-0.10755017, -0.03946237, -0.5153946 , 0.23642367, -0.4680825 ,\n",
" -0.48710173, 0.3539162 , -0.26378518, -0.22426853, -0.00657644,\n", " 0.2587171 , -0.1300325 , -0.05143512, -0.20121185, -0.02474 ,\n",
" 0.11982051, 0.44674066, 0.02232701, 0.2919437 , -0.05375196,\n", " -0.09320115, -0.07455952, 0.10833438, -0.02096028, -0.12492044,\n",
" -0.16767474], dtype=float32),\n", " 0.00582709], dtype=float32),\n",
" '5th-6th': array([ 0.23281994, -0.43167526, 0.0894226 , -0.11401146, 0.02560319,\n", " '5th-6th': array([-0.12893526, 0.27144003, 0.37272307, 0.3963532 , 0.34640747,\n",
" -0.19197138, 0.43012407, -0.26886204, 0.17845912, 0.36679146,\n", " -0.33437288, 0.0193824 , -0.01519158, -0.42908698, 0.05110272,\n",
" -0.23775795, 0.22599484, -0.03654391, 0.14109942, 0.21356976,\n", " 0.01151075, 0.15922028, -0.17880926, -0.36683136, -0.40467307,\n",
" 0.09289707], dtype=float32),\n", " -0.12017028], dtype=float32),\n",
" 'Assoc-voc': array([-0.33012867, 0.03091485, 0.0463426 , -0.26056498, 0.14094928,\n", " 'Assoc-voc': array([ 0.02241084, -0.07670853, -0.22828907, -0.12371975, -0.07486907,\n",
" 0.07926508, 0.23741004, -0.1044822 , 0.2198133 , 0.0026662 ,\n", " -0.29233935, 0.31587106, 0.2165355 , 0.20171323, -0.15870345,\n",
" -0.20316987, -0.5384259 , 0.2922033 , -0.1975719 , -0.34955254,\n", " -0.1275358 , -0.21006238, -0.03274518, -0.14725143, -0.213672 ,\n",
" 0.4160792 ], dtype=float32),\n", " 0.30866137], dtype=float32),\n",
" '9th': array([-0.5900627 , 0.05397911, -0.08968943, 0.01747594, 0.02829137,\n", " '9th': array([ 0.1470835 , -0.0528347 , 0.24995384, -0.21315503, -0.24470845,\n",
" -0.35849443, -0.3720297 , 0.3774286 , 0.41263312, 0.43744153,\n", " 0.819329 , 0.04469828, 0.09546001, 0.24664721, 0.3054443 ,\n",
" -0.52274925, -0.32064512, 0.3371732 , 0.08183856, -0.35107607,\n", " 0.4566717 , 0.14872263, 0.0116579 , 0.2515947 , 0.2023506 ,\n",
" -0.01041517], dtype=float32),\n", " -0.3379088 ], dtype=float32),\n",
" '12th': array([ 0.25146627, 0.25219554, 0.02975183, 0.05702514, -0.24613279,\n", " '12th': array([-0.01843497, 0.21602574, -0.35730916, -0.16129005, 0.34858495,\n",
" -0.21778086, -0.00569818, -0.0637485 , -0.1616422 , -0.30637848,\n", " 0.07911005, -0.09155226, 0.25502652, -0.20713754, -0.2009355 ,\n",
" 0.12083519, -0.14571959, -0.16722937, -0.15352033, -0.4982682 ,\n", " -0.18680803, 0.05695441, -0.20793928, -0.01325957, -0.28487244,\n",
" -0.28606132], dtype=float32),\n", " 0.26250076], dtype=float32),\n",
" '1st-4th': array([ 0.39057913, 0.25798413, -0.4903637 , -0.59042794, -0.02872112,\n", " '1st-4th': array([-0.5274408 , -0.17692605, -0.32478535, -0.15695599, 0.03235544,\n",
" 0.03271537, 0.3033741 , 0.18191384, -0.0404124 , 0.17591539,\n", " -0.37266013, 0.35468644, 0.16074362, -0.36835802, 0.37510112,\n",
" 0.09823196, 0.09790152, 0.1655496 , 0.2734208 , 0.10942121,\n", " 0.0420665 , -0.19070098, 0.33601463, -0.4323496 , -0.19171081,\n",
" 0.11239085], dtype=float32),\n", " -0.27081746], dtype=float32),\n",
" 'Preschool': array([-0.1641325 , -0.25642204, -0.14070551, 0.51633525, -0.04863334,\n", " 'Preschool': array([ 0.07924446, 0.11405066, -0.17461444, -0.11104126, 0.45389435,\n",
" -0.5472205 , 0.21776134, 0.08466767, 0.167773 , -0.22240794,\n", " -0.06884138, -0.07859107, 0.30992216, -0.09668542, -0.03197801,\n",
" 0.08511946, 0.19921534, 0.6440741 , 0.34729373, 0.43439448,\n", " 0.25111035, 0.5209666 , 0.61060447, 0.03642207, 0.05149668,\n",
" -0.4937917 ], dtype=float32)}" " 0.14839056], dtype=float32)}"
] ]
}, },
"execution_count": 18, "execution_count": 18,
...@@ -841,7 +842,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.6" "version": "3.7.7"
} }
}, },
"nbformat": 4, "nbformat": 4,
...
...@@ -1069,7 +1069,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
" 4%|▍ | 41/1001 [00:00<00:02, 405.97it/s]" " 4%|▍ | 40/1001 [00:00<00:02, 392.70it/s]"
] ]
}, },
{ {
...@@ -1083,7 +1083,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 1001/1001 [00:02<00:00, 388.42it/s]\n" "100%|██████████| 1001/1001 [00:02<00:00, 382.63it/s]\n"
] ]
}, },
{ {
...@@ -1124,7 +1124,7 @@
" embed_input=tab_preprocessor.embeddings_input,\n", " embed_input=tab_preprocessor.embeddings_input,\n",
" embed_dropout = 0.1,\n", " embed_dropout = 0.1,\n",
" continuous_cols = continuous_cols,\n", " continuous_cols = continuous_cols,\n",
" batchnorm_cont = True\n", " cont_norm_layer = \"batchnorm\"\n",
")\n", ")\n",
" \n", " \n",
"# DeepText: a stack of 2 LSTMs\n", "# DeepText: a stack of 2 LSTMs\n",
...@@ -1173,15 +1173,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 25/25 [02:06<00:00, 5.08s/it, loss=116]\n", "epoch 1: 100%|██████████| 25/25 [02:05<00:00, 5.03s/it, loss=107]\n",
"valid: 100%|██████████| 7/7 [00:16<00:00, 2.30s/it, loss=100] \n" "valid: 100%|██████████| 7/7 [00:15<00:00, 2.22s/it, loss=129]\n"
] ]
} }
], ],
...@@ -1201,7 +1201,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1215,7 +1215,7 @@
" embed_input=tab_preprocessor.embeddings_input,\n", " embed_input=tab_preprocessor.embeddings_input,\n",
" embed_dropout = 0.1,\n", " embed_dropout = 0.1,\n",
" continuous_cols = continuous_cols,\n", " continuous_cols = continuous_cols,\n",
" batchnorm_cont = True\n", " cont_norm_layer = \"batchnorm\"\n",
")\n", ")\n",
"\n", "\n",
"deeptext = DeepText(\n", "deeptext = DeepText(\n",
...@@ -1238,7 +1238,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1260,7 +1260,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 16,
"metadata": { "metadata": {
"scrolled": false "scrolled": false
}, },
...@@ -1285,7 +1285,7 @@
" (emb_layer_neighbourhood_cleansed): Embedding(33, 64, padding_idx=0)\n", " (emb_layer_neighbourhood_cleansed): Embedding(33, 64, padding_idx=0)\n",
" )\n", " )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n", " (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" (norm): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (cont_norm): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (tab_mlp): MLP(\n", " (tab_mlp): MLP(\n",
" (mlp): Sequential(\n", " (mlp): Sequential(\n",
" (dense_layer_0): Sequential(\n", " (dense_layer_0): Sequential(\n",
...@@ -1423,7 +1423,7 @@
")" ")"
] ]
}, },
"execution_count": 15, "execution_count": 16,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -1443,7 +1443,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1457,7 +1457,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 18,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1470,7 +1470,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 19,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1483,7 +1483,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1509,7 +1509,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 21,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1535,15 +1535,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 22,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 25/25 [02:06<00:00, 5.05s/it, loss=111]\n", "epoch 1: 100%|██████████| 25/25 [02:09<00:00, 5.18s/it, loss=106]\n",
"valid: 100%|██████████| 7/7 [00:15<00:00, 2.24s/it, loss=95.8]" "valid: 100%|██████████| 7/7 [00:16<00:00, 2.31s/it, loss=95.5]"
] ]
}, },
{ {
...@@ -1575,7 +1575,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 23,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1603,7 +1603,7 @@
" 'lr_deephead_0': [0.001, 0.001]}" " 'lr_deephead_0': [0.001, 0.001]}"
] ]
}, },
"execution_count": 22, "execution_count": 23,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...
...@@ -276,7 +276,8 @@
"deeptabular = TabMlp(mlp_hidden_dims=[64,32], \n", "deeptabular = TabMlp(mlp_hidden_dims=[64,32], \n",
" column_idx=tab_preprocessor.column_idx,\n", " column_idx=tab_preprocessor.column_idx,\n",
" embed_input=tab_preprocessor.embeddings_input,\n", " embed_input=tab_preprocessor.embeddings_input,\n",
" continuous_cols=continuous_cols)\n", " continuous_cols=continuous_cols\n",
" )\n",
"model = WideDeep(wide=wide, deeptabular=deeptabular)" "model = WideDeep(wide=wide, deeptabular=deeptabular)"
] ]
}, },
...@@ -305,15 +306,15 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1222/1222 [00:11<00:00, 107.08it/s, loss=0.394, metrics={'acc': 0.8162}]\n", "epoch 1: 100%|██████████| 153/153 [00:04<00:00, 36.06it/s, loss=0.529, metrics={'acc': 0.7448}]\n",
"valid: 100%|██████████| 306/306 [00:01<00:00, 166.01it/s, loss=0.36, metrics={'acc': 0.8331}] \n", "valid: 100%|██████████| 39/39 [00:00<00:00, 68.26it/s, loss=0.389, metrics={'acc': 0.8176}]\n",
"epoch 2: 100%|██████████| 1222/1222 [00:11<00:00, 108.59it/s, loss=0.363, metrics={'acc': 0.829}] \n", "epoch 2: 100%|██████████| 153/153 [00:03<00:00, 39.18it/s, loss=0.401, metrics={'acc': 0.8122}]\n",
"valid: 100%|██████████| 306/306 [00:01<00:00, 166.49it/s, loss=0.353, metrics={'acc': 0.8372}]\n" "valid: 100%|██████████| 39/39 [00:00<00:00, 116.68it/s, loss=0.368, metrics={'acc': 0.8272}]\n"
] ]
} }
], ],
"source": [ "source": [
"trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=2, val_split=0.2)" "trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=2, val_split=0.2, batch_size=256)"
] ]
}, },
{ {
...@@ -386,8 +387,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\r", "epoch 1: 3%|▎ | 5/191 [00:00<00:03, 47.72it/s, loss=0.794, metrics={'acc': 0.5348}]"
" 0%| | 0/1527 [00:00<?, ?it/s]"
] ]
}, },
{ {
...@@ -401,9 +401,9 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1527/1527 [00:08<00:00, 185.49it/s, loss=0.38, metrics={'acc': 0.8248}] \n", "epoch 1: 100%|██████████| 191/191 [00:02<00:00, 67.54it/s, loss=0.504, metrics={'acc': 0.7554}]\n",
"epoch 2: 100%|██████████| 1527/1527 [00:08<00:00, 187.06it/s, loss=0.361, metrics={'acc': 0.8284}]\n", "epoch 2: 100%|██████████| 191/191 [00:02<00:00, 70.24it/s, loss=0.386, metrics={'acc': 0.79}] \n",
" 0%| | 0/1527 [00:00<?, ?it/s]" "epoch 1: 4%|▎ | 7/191 [00:00<00:03, 60.96it/s, loss=0.39, metrics={'acc': 0.7909}] "
] ]
}, },
{ {
...@@ -417,9 +417,9 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1527/1527 [00:12<00:00, 126.49it/s, loss=0.367, metrics={'acc': 0.829}] \n", "epoch 1: 100%|██████████| 191/191 [00:03<00:00, 62.41it/s, loss=0.369, metrics={'acc': 0.8028}]\n",
"epoch 2: 100%|██████████| 1527/1527 [00:12<00:00, 123.37it/s, loss=0.353, metrics={'acc': 0.831}] \n", "epoch 2: 100%|██████████| 191/191 [00:03<00:00, 59.52it/s, loss=0.352, metrics={'acc': 0.8107}]\n",
" 0%| | 0/1527 [00:00<?, ?it/s]" "epoch 1: 3%|▎ | 5/191 [00:00<00:04, 43.10it/s, loss=0.363, metrics={'acc': 0.8418}]"
] ]
}, },
{ {
...@@ -433,13 +433,20 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1527/1527 [00:14<00:00, 108.48it/s, loss=0.346, metrics={'acc': 0.8402}]\n", "epoch 1: 100%|██████████| 191/191 [00:04<00:00, 39.91it/s, loss=0.352, metrics={'acc': 0.8378}]\n",
"epoch 2: 100%|██████████| 1527/1527 [00:14<00:00, 107.37it/s, loss=0.34, metrics={'acc': 0.8426}] \n" "epoch 2: 100%|██████████| 191/191 [00:04<00:00, 43.80it/s, loss=0.344, metrics={'acc': 0.8419}]\n"
] ]
} }
], ],
"source": [ "source": [
"trainer_1.fit(X_wide=X_wide, X_tab=X_tab, target=target, finetune=True, finetune_epochs=2, n_epochs=2)" "trainer_1.fit(\n",
" X_wide=X_wide, \n",
" X_tab=X_tab, \n",
" target=target, \n",
" finetune=True, \n",
" finetune_epochs=2, \n",
" n_epochs=2, \n",
" batch_size=256)"
] ]
}, },
{ {
...@@ -481,8 +488,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\r", "epoch 1: 3%|▎ | 6/172 [00:00<00:02, 58.53it/s, loss=0.988, metrics={'acc': 0.4435}]"
" 0%| | 0/1374 [00:00<?, ?it/s]"
] ]
}, },
{ {
...@@ -496,9 +502,9 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1374/1374 [00:07<00:00, 182.23it/s, loss=0.44, metrics={'acc': 0.7869}] \n", "epoch 1: 100%|██████████| 172/172 [00:02<00:00, 73.06it/s, loss=0.54, metrics={'acc': 0.7276}] \n",
"epoch 2: 100%|██████████| 1374/1374 [00:07<00:00, 185.00it/s, loss=0.362, metrics={'acc': 0.8093}]\n", "epoch 2: 100%|██████████| 172/172 [00:02<00:00, 75.57it/s, loss=0.389, metrics={'acc': 0.7736}]\n",
" 0%| | 0/1374 [00:00<?, ?it/s]" "epoch 1: 3%|▎ | 6/172 [00:00<00:02, 55.48it/s, loss=0.582, metrics={'acc': 0.7728}]"
] ]
}, },
{ {
...@@ -512,9 +518,9 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1374/1374 [00:11<00:00, 122.46it/s, loss=0.376, metrics={'acc': 0.8136}]\n", "epoch 1: 100%|██████████| 172/172 [00:02<00:00, 58.52it/s, loss=0.392, metrics={'acc': 0.7881}]\n",
"epoch 2: 100%|██████████| 1374/1374 [00:12<00:00, 112.71it/s, loss=0.352, metrics={'acc': 0.8194}]\n", "epoch 2: 100%|██████████| 172/172 [00:02<00:00, 58.26it/s, loss=0.353, metrics={'acc': 0.8}] \n",
" 0%| | 0/1374 [00:00<?, ?it/s]" "epoch 1: 2%|▏ | 4/172 [00:00<00:04, 38.87it/s, loss=0.337, metrics={'acc': 0.8589}]"
] ]
}, },
{ {
...@@ -528,15 +534,24 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1374/1374 [00:13<00:00, 103.13it/s, loss=0.347, metrics={'acc': 0.8393}]\n", "epoch 1: 100%|██████████| 172/172 [00:04<00:00, 42.81it/s, loss=0.355, metrics={'acc': 0.8366}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.50it/s, loss=0.345, metrics={'acc': 0.8403}]\n", "valid: 100%|██████████| 20/20 [00:00<00:00, 89.21it/s, loss=0.35, metrics={'acc': 0.8356}] \n",
"epoch 2: 100%|██████████| 1374/1374 [00:13<00:00, 103.84it/s, loss=0.34, metrics={'acc': 0.8419}] \n", "epoch 2: 100%|██████████| 172/172 [00:04<00:00, 41.35it/s, loss=0.346, metrics={'acc': 0.8381}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 160.53it/s, loss=0.343, metrics={'acc': 0.8444}]\n" "valid: 100%|██████████| 20/20 [00:00<00:00, 87.63it/s, loss=0.349, metrics={'acc': 0.8373}]\n"
] ]
} }
], ],
"source": [ "source": [
"trainer_2.fit(X_wide=X_wide, X_tab=X_tab, target=target, val_split=0.1, warmup=True, warmup_epochs=2, n_epochs=2)" "trainer_2.fit(\n",
" X_wide=X_wide, \n",
" X_tab=X_tab, \n",
" target=target, \n",
" val_split=0.1, \n",
" warmup=True, \n",
" warmup_epochs=2, \n",
" n_epochs=2, \n",
" batch_size=256\n",
")"
] ]
}, },
{ {
...@@ -622,6 +637,7 @@
" (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n", " (emb_layer_workclass): Embedding(10, 16, padding_idx=0)\n",
" )\n", " )\n",
" (embedding_dropout): Dropout(p=0.1, inplace=False)\n", " (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
" (cont_norm): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (tab_resnet_blks): DenseResnet(\n", " (tab_resnet_blks): DenseResnet(\n",
" (dense_resnet): Sequential(\n", " (dense_resnet): Sequential(\n",
" (lin1): Linear(in_features=74, out_features=128, bias=True)\n", " (lin1): Linear(in_features=74, out_features=128, bias=True)\n",
...@@ -692,20 +708,20 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1374/1374 [00:23<00:00, 57.89it/s, loss=0.396, metrics={'acc': 0.8115}]\n", "epoch 1: 100%|██████████| 172/172 [00:05<00:00, 29.00it/s, loss=0.453, metrics={'acc': 0.7787}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 137.08it/s, loss=0.364, metrics={'acc': 0.8315}]\n", "valid: 100%|██████████| 20/20 [00:00<00:00, 90.03it/s, loss=0.363, metrics={'acc': 0.8282}]\n",
"epoch 2: 100%|██████████| 1374/1374 [00:23<00:00, 57.57it/s, loss=0.367, metrics={'acc': 0.8262}]\n", "epoch 2: 100%|██████████| 172/172 [00:05<00:00, 32.24it/s, loss=0.371, metrics={'acc': 0.8262}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 140.31it/s, loss=0.357, metrics={'acc': 0.8344}]\n" "valid: 100%|██████████| 20/20 [00:00<00:00, 88.22it/s, loss=0.351, metrics={'acc': 0.8356}]\n"
] ]
} }
], ],
"source": [ "source": [
"trainer_3.fit(X_wide=X_wide, X_tab=X_tab, target=target, val_split=0.1, n_epochs=2)" "trainer_3.fit(X_wide=X_wide, X_tab=X_tab, target=target, val_split=0.1, n_epochs=2, batch_size=256)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 27, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -723,7 +739,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 28, "execution_count": 21,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -738,7 +754,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": 22,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -747,7 +763,7 @@
"<All keys matched successfully>" "<All keys matched successfully>"
] ]
}, },
"execution_count": 29, "execution_count": 22,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -758,7 +774,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 23,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -767,12 +783,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 24,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"tab_deep_layers = list(\n", "tab_deep_layers = list(\n",
" list(list(list(model_3.deeptabular.children())[0].children())[2].children())[\n", " list(list(list(model_3.deeptabular.children())[0].children())[3].children())[\n",
" 0\n", " 0\n",
" ].children()\n", " ].children()\n",
")[::-1][:2]" ")[::-1][:2]"
...@@ -780,7 +796,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 25,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -789,7 +805,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 26,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -822,7 +838,7 @@
" )]" " )]"
] ]
}, },
"execution_count": 33, "execution_count": 26,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -840,7 +856,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 34, "execution_count": 27,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -849,15 +865,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\r", "epoch 1: 5%|▍ | 8/172 [00:00<00:02, 68.51it/s, loss=0.767, metrics={'acc': 0.5605}]"
" 0%| | 0/1374 [00:00<?, ?it/s]"
] ]
}, },
{ {
...@@ -871,9 +886,9 @@ ...@@ -871,9 +886,9 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1374/1374 [00:06<00:00, 199.07it/s, loss=0.382, metrics={'acc': 0.8216}]\n", "epoch 1: 100%|██████████| 172/172 [00:02<00:00, 75.72it/s, loss=0.489, metrics={'acc': 0.7523}]\n",
"epoch 2: 100%|██████████| 1374/1374 [00:06<00:00, 198.04it/s, loss=0.361, metrics={'acc': 0.8268}]\n", "epoch 2: 100%|██████████| 172/172 [00:02<00:00, 64.95it/s, loss=0.383, metrics={'acc': 0.7876}]\n",
" 0%| | 0/1374 [00:00<?, ?it/s]" "epoch 1: 2%|▏ | 3/172 [00:00<00:07, 22.26it/s, loss=0.402, metrics={'acc': 0.788}] "
] ]
}, },
{ {
...@@ -887,8 +902,8 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1374/1374 [00:17<00:00, 77.34it/s, loss=0.39, metrics={'acc': 0.822}] \n", "epoch 1: 100%|██████████| 172/172 [00:08<00:00, 20.71it/s, loss=0.385, metrics={'acc': 0.7986}]\n",
" 0%| | 0/1374 [00:00<?, ?it/s]" "epoch 1: 0%| | 0/172 [00:00<?, ?it/s]"
] ]
}, },
{ {
...@@ -902,8 +917,8 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1374/1374 [00:19<00:00, 70.91it/s, loss=0.376, metrics={'acc': 0.8218}]\n", "epoch 1: 100%|██████████| 172/172 [00:13<00:00, 13.08it/s, loss=0.369, metrics={'acc': 0.8058}]\n",
" 0%| | 0/1374 [00:00<?, ?it/s]" "epoch 1: 2%|▏ | 3/172 [00:00<00:07, 21.56it/s, loss=0.355, metrics={'acc': 0.806}]"
] ]
}, },
{ {
...@@ -917,8 +932,8 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1374/1374 [00:21<00:00, 64.08it/s, loss=0.364, metrics={'acc': 0.8234}]\n", "epoch 1: 100%|██████████| 172/172 [00:07<00:00, 22.34it/s, loss=0.361, metrics={'acc': 0.8108}]\n",
" 0%| | 0/1374 [00:00<?, ?it/s]" "epoch 1: 1%| | 2/172 [00:00<00:09, 17.34it/s, loss=0.334, metrics={'acc': 0.8581}]"
] ]
}, },
{ {
...@@ -932,10 +947,9 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 1374/1374 [00:23<00:00, 59.25it/s, loss=0.347, metrics={'acc': 0.8388}]\n", "epoch 1: 100%|██████████| 172/172 [00:16<00:00, 10.31it/s, loss=0.353, metrics={'acc': 0.8366}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 142.44it/s, loss=0.349, metrics={'acc': 0.8411}]\n", "valid: 100%|██████████| 20/20 [00:00<00:00, 40.53it/s, loss=0.345, metrics={'acc': 0.8405}]\n",
"epoch 2: 100%|██████████| 1374/1374 [00:22<00:00, 60.43it/s, loss=0.341, metrics={'acc': 0.8413}]\n", "epoch 2: 89%|████████▉ | 153/172 [00:06<00:00, 28.91it/s, loss=0.342, metrics={'acc': 0.8399}]"
"valid: 100%|██████████| 153/153 [00:01<00:00, 143.93it/s, loss=0.347, metrics={'acc': 0.8381}]\n"
] ]
} }
], ],
...@@ -950,7 +964,9 @@ ...@@ -950,7 +964,9 @@
" finetune_deeptabular_gradual=True,\n", " finetune_deeptabular_gradual=True,\n",
" finetune_deeptabular_layers = tab_layers,\n", " finetune_deeptabular_layers = tab_layers,\n",
" finetune_deeptabular_max_lr = 0.01,\n", " finetune_deeptabular_max_lr = 0.01,\n",
" n_epochs=2)" " n_epochs=2,\n",
" batch_size=256\n",
")"
] ]
}, },
{ {
...@@ -962,7 +978,7 @@ ...@@ -962,7 +978,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -975,7 +991,7 @@ ...@@ -975,7 +991,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -984,25 +1000,16 @@ ...@@ -984,25 +1000,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 1374/1374 [00:11<00:00, 124.80it/s, loss=0.376, metrics={'acc': 0.823}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 168.80it/s, loss=0.35, metrics={'acc': 0.8405}] \n"
]
}
],
"source": [ "source": [
"trainer_5.fit(X_wide=X_wide, X_tab=X_tab, target=target, val_split=0.1, n_epochs=1)" "trainer_5.fit(X_wide=X_wide, X_tab=X_tab, target=target, val_split=0.1, n_epochs=1, batch_size=256)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1011,7 +1018,7 @@ ...@@ -1011,7 +1018,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1024,20 +1031,9 @@ ...@@ -1024,20 +1031,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 41, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"model_5.load_state_dict(torch.load(\"models_dir/model_5.pt\"))" "model_5.load_state_dict(torch.load(\"models_dir/model_5.pt\"))"
] ]
...@@ -1051,7 +1047,7 @@ ...@@ -1051,7 +1047,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1060,40 +1056,9 @@ ...@@ -1060,40 +1056,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/1374 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training deeptabular for 2 epochs\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 1374/1374 [00:11<00:00, 122.12it/s, loss=0.369, metrics={'acc': 0.8299}]\n",
"epoch 2: 100%|██████████| 1374/1374 [00:11<00:00, 118.44it/s, loss=0.352, metrics={'acc': 0.8329}]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fine-tuning finished\n"
]
}
],
"source": [ "source": [
"trainer_6.fit(\n", "trainer_6.fit(\n",
" X_wide=X_wide, \n", " X_wide=X_wide, \n",
...@@ -1103,14 +1068,15 @@ ...@@ -1103,14 +1068,15 @@
" finetune=True, \n", " finetune=True, \n",
" finetune_epochs=2,\n", " finetune_epochs=2,\n",
" finetune_max_lr=0.01,\n", " finetune_max_lr=0.01,\n",
" stop_after_finetuning=True\n", " stop_after_finetuning=True,\n",
" batch_size=256\n",
" \n", " \n",
") " ") "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 44, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
......
...@@ -1011,7 +1011,7 @@ ...@@ -1011,7 +1011,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
" 8%|▊ | 83/1001 [00:00<00:02, 412.93it/s]" " 4%|▍ | 42/1001 [00:00<00:02, 419.42it/s]"
] ]
}, },
{ {
...@@ -1025,7 +1025,7 @@ ...@@ -1025,7 +1025,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 1001/1001 [00:02<00:00, 411.22it/s]\n" "100%|██████████| 1001/1001 [00:02<00:00, 408.24it/s]\n"
] ]
}, },
{ {
...@@ -1077,7 +1077,7 @@ ...@@ -1077,7 +1077,7 @@
" embed_input=tab_preprocessor.embeddings_input,\n", " embed_input=tab_preprocessor.embeddings_input,\n",
" embed_dropout = 0.1,\n", " embed_dropout = 0.1,\n",
" continuous_cols = continuous_cols,\n", " continuous_cols = continuous_cols,\n",
" batchnorm_cont = True\n", " cont_norm_layer = \"batchnorm\"\n",
")\n", ")\n",
" \n", " \n",
"# Pretrained Resnet 18 (default is all but last 2 conv blocks frozen) plus a FC-Head 512->256->128\n", "# Pretrained Resnet 18 (default is all but last 2 conv blocks frozen) plus a FC-Head 512->256->128\n",
...@@ -1194,8 +1194,8 @@ ...@@ -1194,8 +1194,8 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 25/25 [02:00<00:00, 4.82s/it, loss=115]\n", "epoch 1: 100%|██████████| 25/25 [02:03<00:00, 4.94s/it, loss=111]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.12s/it, loss=129]\n" "valid: 100%|██████████| 7/7 [00:15<00:00, 2.17s/it, loss=94.8]\n"
] ]
} }
], ],
...@@ -1647,14 +1647,14 @@ ...@@ -1647,14 +1647,14 @@
" self.correct_count = 0\n", " self.correct_count = 0\n",
" self.total_count = 0\n", " self.total_count = 0\n",
"\n", "\n",
" # metric name needs to be defined\n", " #  metric name needs to be defined\n",
" self._name = \"acc\"\n", " self._name = \"acc\"\n",
"\n", "\n",
" def reset(self):\n", " def reset(self):\n",
" self.correct_count = 0\n", " self.correct_count = 0\n",
" self.total_count = 0\n", " self.total_count = 0\n",
"\n", "\n",
" def __call__(self, y_pred: Tensor, y_true: Tensor) -> float:\n", " def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:\n",
" num_classes = y_pred.size(1)\n", " num_classes = y_pred.size(1)\n",
"\n", "\n",
" if num_classes == 1:\n", " if num_classes == 1:\n",
...@@ -1667,7 +1667,7 @@ ...@@ -1667,7 +1667,7 @@
" self.correct_count += y_pred.eq(y_true).sum().item()\n", " self.correct_count += y_pred.eq(y_true).sum().item()\n",
" self.total_count += len(y_pred)\n", " self.total_count += len(y_pred)\n",
" accuracy = float(self.correct_count) / float(self.total_count)\n", " accuracy = float(self.correct_count) / float(self.total_count)\n",
" return accuracy" " return np.array(accuracy)\n"
] ]
}, },
{ {
...@@ -1691,7 +1691,7 @@ ...@@ -1691,7 +1691,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 21,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1719,7 +1719,7 @@ ...@@ -1719,7 +1719,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 25, "execution_count": 22,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1728,23 +1728,23 @@ ...@@ -1728,23 +1728,23 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 26, "execution_count": 23,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 611/611 [00:06<00:00, 98.59it/s, loss=0.375, metrics={'acc': 0.8236}] \n", "epoch 1: 100%|██████████| 611/611 [00:06<00:00, 90.27it/s, loss=0.455, metrics={'acc': 0.7866}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 154.88it/s, loss=0.361, metrics={'acc': 0.8327}]\n", "valid: 100%|██████████| 153/153 [00:00<00:00, 155.16it/s, loss=0.367, metrics={'acc': 0.8318}]\n",
"epoch 2: 100%|██████████| 611/611 [00:06<00:00, 95.06it/s, loss=0.365, metrics={'acc': 0.8284}] \n", "epoch 2: 100%|██████████| 611/611 [00:06<00:00, 89.76it/s, loss=0.379, metrics={'acc': 0.8233}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 150.72it/s, loss=0.354, metrics={'acc': 0.838}] \n", "valid: 100%|██████████| 153/153 [00:01<00:00, 151.92it/s, loss=0.354, metrics={'acc': 0.8362}]\n",
"epoch 3: 100%|██████████| 611/611 [00:06<00:00, 98.56it/s, loss=0.358, metrics={'acc': 0.833}] \n", "epoch 3: 100%|██████████| 611/611 [00:06<00:00, 89.63it/s, loss=0.365, metrics={'acc': 0.8296}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 159.22it/s, loss=0.353, metrics={'acc': 0.8389}]\n", "valid: 100%|██████████| 153/153 [00:00<00:00, 154.56it/s, loss=0.35, metrics={'acc': 0.8383}] \n",
"epoch 4: 100%|██████████| 611/611 [00:06<00:00, 99.64it/s, loss=0.354, metrics={'acc': 0.833}] \n", "epoch 4: 100%|██████████| 611/611 [00:06<00:00, 87.99it/s, loss=0.357, metrics={'acc': 0.8333}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.01it/s, loss=0.35, metrics={'acc': 0.8413}] \n", "valid: 100%|██████████| 153/153 [00:00<00:00, 155.60it/s, loss=0.348, metrics={'acc': 0.8406}]\n",
"epoch 5: 100%|██████████| 611/611 [00:06<00:00, 99.59it/s, loss=0.351, metrics={'acc': 0.8355}] \n", "epoch 5: 100%|██████████| 611/611 [00:07<00:00, 87.18it/s, loss=0.354, metrics={'acc': 0.835}] \n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 137.04it/s, loss=0.348, metrics={'acc': 0.8405}]\n" "valid: 100%|██████████| 153/153 [00:01<00:00, 149.32it/s, loss=0.346, metrics={'acc': 0.8402}]\n"
] ]
} }
], ],
...@@ -1754,7 +1754,7 @@ ...@@ -1754,7 +1754,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 27, "execution_count": 24,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1763,7 +1763,7 @@ ...@@ -1763,7 +1763,7 @@
"{'beginning': [1, 2, 3, 4, 5], 'end': [1, 2, 3, 4, 5]}" "{'beginning': [1, 2, 3, 4, 5], 'end': [1, 2, 3, 4, 5]}"
] ]
}, },
"execution_count": 27, "execution_count": 24,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
......
...@@ -11,9 +11,18 @@ ...@@ -11,9 +11,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/javier/.pyenv/versions/3.7.7/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [ "source": [
"import pickle\n", "import pickle\n",
"import numpy as np\n", "import numpy as np\n",
...@@ -31,7 +40,7 @@ ...@@ -31,7 +40,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -190,7 +199,7 @@ ...@@ -190,7 +199,7 @@
"4 30 United-States <=50K " "4 30 United-States <=50K "
] ]
}, },
"execution_count": 4, "execution_count": 2,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -202,7 +211,7 @@ ...@@ -202,7 +211,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -361,7 +370,7 @@ ...@@ -361,7 +370,7 @@
"4 30 United-States 0 " "4 30 United-States 0 "
] ]
}, },
"execution_count": 5, "execution_count": 3,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -377,7 +386,7 @@ ...@@ -377,7 +386,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -388,7 +397,7 @@ ...@@ -388,7 +397,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -409,7 +418,7 @@ ...@@ -409,7 +418,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -419,7 +428,7 @@ ...@@ -419,7 +428,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -432,7 +441,7 @@ ...@@ -432,7 +441,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -443,7 +452,7 @@ ...@@ -443,7 +452,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -456,7 +465,7 @@ ...@@ -456,7 +465,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -471,7 +480,7 @@ ...@@ -471,7 +480,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -479,20 +488,20 @@ ...@@ -479,20 +488,20 @@
"text/plain": [ "text/plain": [
"WideDeep(\n", "WideDeep(\n",
" (wide): Wide(\n", " (wide): Wide(\n",
" (wide_linear): Embedding(775, 1, padding_idx=0)\n", " (wide_linear): Embedding(773, 1, padding_idx=0)\n",
" )\n", " )\n",
" (deeptabular): Sequential(\n", " (deeptabular): Sequential(\n",
" (0): TabMlp(\n", " (0): TabMlp(\n",
" (embed_layers): ModuleDict(\n", " (embed_layers): ModuleDict(\n",
" (emb_layer_age): Embedding(75, 18, padding_idx=0)\n", " (emb_layer_age): Embedding(75, 18, padding_idx=0)\n",
" (emb_layer_capital_gain): Embedding(122, 23, padding_idx=0)\n", " (emb_layer_capital_gain): Embedding(121, 23, padding_idx=0)\n",
" (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0)\n", " (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0)\n",
" (emb_layer_education): Embedding(17, 8, padding_idx=0)\n", " (emb_layer_education): Embedding(17, 8, padding_idx=0)\n",
" (emb_layer_educational_num): Embedding(17, 8, padding_idx=0)\n", " (emb_layer_educational_num): Embedding(17, 8, padding_idx=0)\n",
" (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n", " (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n",
" (emb_layer_hours_per_week): Embedding(97, 21, padding_idx=0)\n", " (emb_layer_hours_per_week): Embedding(96, 20, padding_idx=0)\n",
" (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)\n", " (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)\n",
" (emb_layer_native_country): Embedding(42, 13, padding_idx=0)\n", " (emb_layer_native_country): Embedding(43, 13, padding_idx=0)\n",
" (emb_layer_occupation): Embedding(16, 7, padding_idx=0)\n", " (emb_layer_occupation): Embedding(16, 7, padding_idx=0)\n",
" (emb_layer_race): Embedding(6, 4, padding_idx=0)\n", " (emb_layer_race): Embedding(6, 4, padding_idx=0)\n",
" (emb_layer_relationship): Embedding(7, 4, padding_idx=0)\n", " (emb_layer_relationship): Embedding(7, 4, padding_idx=0)\n",
...@@ -503,7 +512,7 @@ ...@@ -503,7 +512,7 @@
" (mlp): Sequential(\n", " (mlp): Sequential(\n",
" (dense_layer_0): Sequential(\n", " (dense_layer_0): Sequential(\n",
" (0): Dropout(p=0.1, inplace=False)\n", " (0): Dropout(p=0.1, inplace=False)\n",
" (1): Linear(in_features=139, out_features=200, bias=True)\n", " (1): Linear(in_features=138, out_features=200, bias=True)\n",
" (2): ReLU(inplace=True)\n", " (2): ReLU(inplace=True)\n",
" )\n", " )\n",
" (dense_layer_1): Sequential(\n", " (dense_layer_1): Sequential(\n",
...@@ -519,7 +528,7 @@ ...@@ -519,7 +528,7 @@
")" ")"
] ]
}, },
"execution_count": 13, "execution_count": 11,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -530,16 +539,16 @@ ...@@ -530,16 +539,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 153/153 [00:04<00:00, 30.90it/s, loss=0.456, metrics={'acc': 0.7925}]\n", "epoch 1: 100%|██████████| 153/153 [00:04<00:00, 31.06it/s, loss=0.479, metrics={'acc': 0.7839}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 89.84it/s, loss=0.384, metrics={'acc': 0.8313}]\n", "valid: 100%|██████████| 20/20 [00:00<00:00, 48.70it/s, loss=0.348, metrics={'acc': 0.8444}]\n",
" 0%| | 0/153 [00:00<?, ?it/s]" "epoch 2: 3%|▎ | 4/153 [00:00<00:04, 32.06it/s, loss=0.365, metrics={'acc': 0.8385}]"
] ]
}, },
{ {
...@@ -547,15 +556,15 @@ ...@@ -547,15 +556,15 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\n", "\n",
"Epoch 00001: val_loss improved from inf to 0.38447, saving model to tmp_dir/adult_tabmlp_model_1.p\n" "Epoch 00001: val_loss improved from inf to 0.34800, saving model to tmp_dir/adult_tabmlp_model_1.p\n"
] ]
}, },
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 2: 100%|██████████| 153/153 [00:03<00:00, 38.97it/s, loss=0.364, metrics={'acc': 0.8343}]\n", "epoch 2: 100%|██████████| 153/153 [00:04<00:00, 32.33it/s, loss=0.354, metrics={'acc': 0.8379}]\n",
"valid: 100%|██████████| 20/20 [00:00<00:00, 95.78it/s, loss=0.347, metrics={'acc': 0.8366}] " "valid: 100%|██████████| 20/20 [00:00<00:00, 92.91it/s, loss=0.322, metrics={'acc': 0.8511}] "
] ]
}, },
{ {
...@@ -563,7 +572,7 @@ ...@@ -563,7 +572,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\n", "\n",
"Epoch 00002: val_loss improved from 0.38447 to 0.34708, saving model to tmp_dir/adult_tabmlp_model_2.p\n", "Epoch 00002: val_loss improved from 0.34800 to 0.32204, saving model to tmp_dir/adult_tabmlp_model_2.p\n",
"Model weights restored to best epoch: 2\n" "Model weights restored to best epoch: 2\n"
] ]
}, },
...@@ -610,7 +619,7 @@ ...@@ -610,7 +619,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -619,7 +628,7 @@ ...@@ -619,7 +628,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -637,7 +646,7 @@ ...@@ -637,7 +646,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -653,7 +662,7 @@ ...@@ -653,7 +662,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 16,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -662,7 +671,7 @@ ...@@ -662,7 +671,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -686,7 +695,7 @@ ...@@ -686,7 +695,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 18,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -721,7 +730,7 @@ ...@@ -721,7 +730,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 19,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -731,7 +740,7 @@ ...@@ -731,7 +740,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -741,7 +750,7 @@ ...@@ -741,7 +750,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 23, "execution_count": 21,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -751,7 +760,7 @@ ...@@ -751,7 +760,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 22,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -792,7 +801,7 @@ ...@@ -792,7 +801,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 25, "execution_count": 23,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -835,94 +844,94 @@ ...@@ -835,94 +844,94 @@
" </thead>\n", " </thead>\n",
" <tbody>\n", " <tbody>\n",
" <tr>\n", " <tr>\n",
" <th>42693</th>\n", " <th>32823</th>\n",
" <td>19</td>\n", " <td>33</td>\n",
" <td>Private</td>\n", " <td>Private</td>\n",
" <td>85690</td>\n", " <td>201988</td>\n",
" <td>HS-grad</td>\n", " <td>Masters</td>\n",
" <td>9</td>\n", " <td>14</td>\n",
" <td>Never-married</td>\n", " <td>Married-civ-spouse</td>\n",
" <td>Machine-op-inspct</td>\n", " <td>Prof-specialty</td>\n",
" <td>Unmarried</td>\n", " <td>Husband</td>\n",
" <td>White</td>\n", " <td>White</td>\n",
" <td>Male</td>\n", " <td>Male</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" <td>30</td>\n", " <td>45</td>\n",
" <td>United-States</td>\n", " <td>United-States</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <th>20340</th>\n", " <th>40713</th>\n",
" <td>39</td>\n", " <td>31</td>\n",
" <td>Private</td>\n", " <td>Private</td>\n",
" <td>63910</td>\n", " <td>231826</td>\n",
" <td>Some-college</td>\n", " <td>HS-grad</td>\n",
" <td>10</td>\n", " <td>9</td>\n",
" <td>Never-married</td>\n", " <td>Married-civ-spouse</td>\n",
" <td>Adm-clerical</td>\n", " <td>Other-service</td>\n",
" <td>Own-child</td>\n", " <td>Husband</td>\n",
" <td>Asian-Pac-Islander</td>\n", " <td>White</td>\n",
" <td>Female</td>\n", " <td>Male</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" <td>40</td>\n", " <td>52</td>\n",
" <td>United-States</td>\n", " <td>Mexico</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <th>31992</th>\n", " <th>16020</th>\n",
" <td>27</td>\n", " <td>38</td>\n",
" <td>Private</td>\n", " <td>Private</td>\n",
" <td>166210</td>\n", " <td>24126</td>\n",
" <td>HS-grad</td>\n", " <td>Some-college</td>\n",
" <td>9</td>\n", " <td>10</td>\n",
" <td>Divorced</td>\n", " <td>Divorced</td>\n",
" <td>Craft-repair</td>\n", " <td>Exec-managerial</td>\n",
" <td>Not-in-family</td>\n", " <td>Not-in-family</td>\n",
" <td>White</td>\n", " <td>White</td>\n",
" <td>Male</td>\n", " <td>Female</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" <td>50</td>\n", " <td>40</td>\n",
" <td>United-States</td>\n", " <td>United-States</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <th>44560</th>\n", " <th>32766</th>\n",
" <td>34</td>\n", " <td>38</td>\n",
" <td>Federal-gov</td>\n", " <td>State-gov</td>\n",
" <td>284703</td>\n", " <td>312528</td>\n",
" <td>Some-college</td>\n", " <td>Bachelors</td>\n",
" <td>10</td>\n", " <td>13</td>\n",
" <td>Married-civ-spouse</td>\n", " <td>Married-civ-spouse</td>\n",
" <td>Machine-op-inspct</td>\n", " <td>Exec-managerial</td>\n",
" <td>Husband</td>\n", " <td>Husband</td>\n",
" <td>Black</td>\n", " <td>White</td>\n",
" <td>Male</td>\n", " <td>Male</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" <td>52</td>\n", " <td>37</td>\n",
" <td>United-States</td>\n", " <td>United-States</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <th>31965</th>\n", " <th>9713</th>\n",
" <td>60</td>\n", " <td>40</td>\n",
" <td>Self-emp-not-inc</td>\n", " <td>Self-emp-not-inc</td>\n",
" <td>73091</td>\n", " <td>121012</td>\n",
" <td>HS-grad</td>\n", " <td>Prof-school</td>\n",
" <td>9</td>\n", " <td>15</td>\n",
" <td>Separated</td>\n", " <td>Married-civ-spouse</td>\n",
" <td>Other-service</td>\n", " <td>Prof-specialty</td>\n",
" <td>Not-in-family</td>\n", " <td>Husband</td>\n",
" <td>Black</td>\n", " <td>White</td>\n",
" <td>Male</td>\n", " <td>Male</td>\n",
" <td>0</td>\n", " <td>0</td>\n",
" <td>1876</td>\n", " <td>1977</td>\n",
" <td>50</td>\n", " <td>50</td>\n",
" <td>United-States</td>\n", " <td>United-States</td>\n",
" <td>0</td>\n", " <td>1</td>\n",
" </tr>\n", " </tr>\n",
" </tbody>\n", " </tbody>\n",
"</table>\n", "</table>\n",
...@@ -930,35 +939,28 @@ ...@@ -930,35 +939,28 @@
], ],
"text/plain": [ "text/plain": [
" age workclass fnlwgt education educational_num \\\n", " age workclass fnlwgt education educational_num \\\n",
"42693 19 Private 85690 HS-grad 9 \n", "32823 33 Private 201988 Masters 14 \n",
"20340 39 Private 63910 Some-college 10 \n", "40713 31 Private 231826 HS-grad 9 \n",
"31992 27 Private 166210 HS-grad 9 \n", "16020 38 Private 24126 Some-college 10 \n",
"44560 34 Federal-gov 284703 Some-college 10 \n", "32766 38 State-gov 312528 Bachelors 13 \n",
"31965 60 Self-emp-not-inc 73091 HS-grad 9 \n", "9713 40 Self-emp-not-inc 121012 Prof-school 15 \n",
"\n", "\n",
" marital_status occupation relationship \\\n", " marital_status occupation relationship race gender \\\n",
"42693 Never-married Machine-op-inspct Unmarried \n", "32823 Married-civ-spouse Prof-specialty Husband White Male \n",
"20340 Never-married Adm-clerical Own-child \n", "40713 Married-civ-spouse Other-service Husband White Male \n",
"31992 Divorced Craft-repair Not-in-family \n", "16020 Divorced Exec-managerial Not-in-family White Female \n",
"44560 Married-civ-spouse Machine-op-inspct Husband \n", "32766 Married-civ-spouse Exec-managerial Husband White Male \n",
"31965 Separated Other-service Not-in-family \n", "9713 Married-civ-spouse Prof-specialty Husband White Male \n",
"\n", "\n",
" race gender capital_gain capital_loss hours_per_week \\\n", " capital_gain capital_loss hours_per_week native_country target \n",
"42693 White Male 0 0 30 \n", "32823 0 0 45 United-States 0 \n",
"20340 Asian-Pac-Islander Female 0 0 40 \n", "40713 0 0 52 Mexico 0 \n",
"31992 White Male 0 0 50 \n", "16020 0 0 40 United-States 0 \n",
"44560 Black Male 0 0 52 \n", "32766 0 0 37 United-States 0 \n",
"31965 Black Male 0 1876 50 \n", "9713 0 1977 50 United-States 1 "
"\n",
" native_country target \n",
"42693 United-States 0 \n",
"20340 United-States 0 \n",
"31992 United-States 0 \n",
"44560 United-States 0 \n",
"31965 United-States 0 "
] ]
}, },
"execution_count": 25, "execution_count": 23,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -969,7 +971,7 @@ ...@@ -969,7 +971,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 26, "execution_count": 24,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -979,7 +981,7 @@ ...@@ -979,7 +981,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 27, "execution_count": 25,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -989,7 +991,7 @@ ...@@ -989,7 +991,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": 26,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1000,7 +1002,7 @@ ...@@ -1000,7 +1002,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 27,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1014,7 +1016,7 @@ ...@@ -1014,7 +1016,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 28,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1023,7 +1025,7 @@ ...@@ -1023,7 +1025,7 @@
"<All keys matched successfully>" "<All keys matched successfully>"
] ]
}, },
"execution_count": 31, "execution_count": 28,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -1034,7 +1036,7 @@ ...@@ -1034,7 +1036,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 29,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1046,14 +1048,14 @@ ...@@ -1046,14 +1048,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 30,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"predict: 100%|██████████| 20/20 [00:00<00:00, 119.92it/s]\n" "predict: 100%|██████████| 20/20 [00:00<00:00, 86.04it/s]\n"
] ]
} }
], ],
...@@ -1063,7 +1065,7 @@ ...@@ -1063,7 +1065,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 34, "execution_count": 31,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1072,16 +1074,16 @@ ...@@ -1072,16 +1074,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": 32,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"0.8517911975435005" "0.8554759467758444"
] ]
}, },
"execution_count": 35, "execution_count": 32,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -1092,7 +1094,7 @@ ...@@ -1092,7 +1094,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 33,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
......
...@@ -75,7 +75,7 @@ ...@@ -75,7 +75,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -634,7 +634,7 @@ ...@@ -634,7 +634,7 @@
"4 0.68 -0.59 2.0 -36.0 -6.9 2.02 0.14 -0.23 " "4 0.68 -0.59 2.0 -36.0 -6.9 2.02 0.14 -0.23 "
] ]
}, },
"execution_count": 3, "execution_count": 2,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -647,7 +647,7 @@ ...@@ -647,7 +647,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -658,7 +658,7 @@ ...@@ -658,7 +658,7 @@
"Name: target, dtype: int64" "Name: target, dtype: int64"
] ]
}, },
"execution_count": 4, "execution_count": 3,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -670,7 +670,7 @@ ...@@ -670,7 +670,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -680,7 +680,7 @@ ...@@ -680,7 +680,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -697,7 +697,7 @@ ...@@ -697,7 +697,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -706,7 +706,7 @@ ...@@ -706,7 +706,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -732,7 +732,7 @@ ...@@ -732,7 +732,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -743,7 +743,7 @@ ...@@ -743,7 +743,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -752,6 +752,7 @@ ...@@ -752,6 +752,7 @@
"WideDeep(\n", "WideDeep(\n",
" (deeptabular): Sequential(\n", " (deeptabular): Sequential(\n",
" (0): TabMlp(\n", " (0): TabMlp(\n",
" (cont_norm): BatchNorm1d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (tab_mlp): MLP(\n", " (tab_mlp): MLP(\n",
" (mlp): Sequential(\n", " (mlp): Sequential(\n",
" (dense_layer_0): Sequential(\n", " (dense_layer_0): Sequential(\n",
...@@ -787,7 +788,7 @@ ...@@ -787,7 +788,7 @@
")" ")"
] ]
}, },
"execution_count": 10, "execution_count": 9,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -802,7 +803,7 @@ ...@@ -802,7 +803,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -860,12 +861,12 @@ ...@@ -860,12 +861,12 @@
"Consider using one of the following signatures instead:\n", "Consider using one of the following signatures instead:\n",
"\tnonzero(Tensor input, *, bool as_tuple) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:766.)\n", "\tnonzero(Tensor input, *, bool as_tuple) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:766.)\n",
" meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()\n", " meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()\n",
"epoch 1: 100%|██████████| 208/208 [00:02<00:00, 78.94it/s, loss=0.188, metrics={'Accuracy': [0.9249, 0.9249], 'Precision': 0.9249, 'Recall': [0.9249, 0.9249], 'F1': [0.9244, 0.9253]}] \n", "epoch 1: 100%|██████████| 208/208 [00:02<00:00, 76.68it/s, loss=0.225, metrics={'Accuracy': [0.927, 0.8861], 'Precision': 0.9064, 'Recall': [0.927, 0.8861], 'F1': [0.9075, 0.9052]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 107.92it/s, loss=0.0857, metrics={'Accuracy': [0.9664, 0.9147], 'Precision': 0.9659, 'Recall': [0.9664, 0.9147], 'F1': [0.9825, 0.322]}] \n", "valid: 100%|██████████| 292/292 [00:02<00:00, 107.00it/s, loss=0.104, metrics={'Accuracy': [0.9626, 0.875], 'Precision': 0.9618, 'Recall': [0.9626, 0.875], 'F1': [0.9804, 0.2886]}] \n",
"epoch 2: 100%|██████████| 208/208 [00:02<00:00, 88.97it/s, loss=0.121, metrics={'Accuracy': [0.9521, 0.9491], 'Precision': 0.9506, 'Recall': [0.9521, 0.9491], 'F1': [0.9512, 0.95]}] \n", "epoch 2: 100%|██████████| 208/208 [00:02<00:00, 84.52it/s, loss=0.152, metrics={'Accuracy': [0.9471, 0.9298], 'Precision': 0.9384, 'Recall': [0.9471, 0.9298], 'F1': [0.9384, 0.9383]}]\n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 109.12it/s, loss=0.0894, metrics={'Accuracy': [0.9613, 0.9302], 'Precision': 0.961, 'Recall': [0.9613, 0.9302], 'F1': [0.98, 0.297]}] \n", "valid: 100%|██████████| 292/292 [00:02<00:00, 107.06it/s, loss=0.0915, metrics={'Accuracy': [0.968, 0.8906], 'Precision': 0.9673, 'Recall': [0.968, 0.8906], 'F1': [0.9833, 0.3258]}] \n",
"epoch 3: 100%|██████████| 208/208 [00:02<00:00, 86.76it/s, loss=0.102, metrics={'Accuracy': [0.9534, 0.9631], 'Precision': 0.9583, 'Recall': [0.9534, 0.9631], 'F1': [0.9576, 0.9591]}]\n", "epoch 3: 100%|██████████| 208/208 [00:02<00:00, 81.70it/s, loss=0.134, metrics={'Accuracy': [0.949, 0.9407], 'Precision': 0.9448, 'Recall': [0.949, 0.9407], 'F1': [0.9446, 0.9451]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 106.86it/s, loss=0.116, metrics={'Accuracy': [0.9437, 0.9457], 'Precision': 0.9437, 'Recall': [0.9437, 0.9457], 'F1': [0.9708, 0.2293]}]\n" "valid: 100%|██████████| 292/292 [00:02<00:00, 104.52it/s, loss=0.0847, metrics={'Accuracy': [0.9679, 0.8828], 'Precision': 0.9672, 'Recall': [0.9679, 0.8828], 'F1': [0.9832, 0.3229]}]\n"
] ]
}, },
{ {
...@@ -929,67 +930,72 @@ ...@@ -929,67 +930,72 @@
" <tbody>\n", " <tbody>\n",
" <tr>\n", " <tr>\n",
" <th>0</th>\n", " <th>0</th>\n",
" <td>0.187643</td>\n", " <td>0.225108</td>\n",
" <td>[0.92486894, 0.9248898]</td>\n", " <td>[0.9270001649856567, 0.8861073851585388]</td>\n",
" <td>0.92487943</td>\n", " <td>0.906365</td>\n",
" <td>[0.92486894, 0.9248898]</td>\n", " <td>[0.9270001649856567, 0.8861073851585388]</td>\n",
" <td>[0.9244203, 0.92533314]</td>\n", " <td>[0.9074797034263611, 0.9052220582962036]</td>\n",
" <td>0.085667</td>\n", " <td>0.103954</td>\n",
" <td>[0.96635747, 0.9147287]</td>\n", " <td>[0.962550163269043, 0.875]</td>\n",
" <td>0.96590054</td>\n", " <td>0.961784</td>\n",
" <td>[0.96635747, 0.9147287]</td>\n", " <td>[0.962550163269043, 0.875]</td>\n",
" <td>[0.98251045, 0.32196453]</td>\n", " <td>[0.9803645014762878, 0.28863346576690674]</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <th>1</th>\n", " <th>1</th>\n",
" <td>0.121485</td>\n", " <td>0.152386</td>\n",
" <td>[0.9521358, 0.9490831]</td>\n", " <td>[0.9471125602722168, 0.9297876358032227]</td>\n",
" <td>0.9506268</td>\n", " <td>0.938380</td>\n",
" <td>[0.9521358, 0.9490831]</td>\n", " <td>[0.9471125602722168, 0.9297876358032227]</td>\n",
" <td>[0.95122886, 0.95000976]</td>\n", " <td>[0.9384452104568481, 0.9383144974708557]</td>\n",
" <td>0.089378</td>\n", " <td>0.091541</td>\n",
" <td>[0.9613042, 0.9302326]</td>\n", " <td>[0.9680188298225403, 0.890625]</td>\n",
" <td>0.9610292</td>\n", " <td>0.967341</td>\n",
" <td>[0.9613042, 0.9302326]</td>\n", " <td>[0.9680188298225403, 0.890625]</td>\n",
" <td>[0.9799591, 0.2970297]</td>\n", " <td>[0.9832653999328613, 0.3257790505886078]</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <th>2</th>\n", " <th>2</th>\n",
" <td>0.102038</td>\n", " <td>0.134425</td>\n",
" <td>[0.9534429, 0.96310383]</td>\n", " <td>[0.9490371346473694, 0.9407152533531189]</td>\n",
" <td>0.95834136</td>\n", " <td>0.944841</td>\n",
" <td>[0.9534429, 0.96310383]</td>\n", " <td>[0.9490371346473694, 0.9407152533531189]</td>\n",
" <td>[0.9575638, 0.95909095]</td>\n", " <td>[0.9446273446083069, 0.9450528621673584]</td>\n",
" <td>0.115516</td>\n", " <td>0.084717</td>\n",
" <td>[0.9437215, 0.9457364]</td>\n", " <td>[0.967949628829956, 0.8828125]</td>\n",
" <td>0.9437393</td>\n", " <td>0.967204</td>\n",
" <td>[0.9437215, 0.9457364]</td>\n", " <td>[0.967949628829956, 0.8828125]</td>\n",
" <td>[0.97080404, 0.22932333]</td>\n", " <td>[0.9831950664520264, 0.3229461908340454]</td>\n",
" </tr>\n", " </tr>\n",
" </tbody>\n", " </tbody>\n",
"</table>\n", "</table>\n",
"</div>" "</div>"
], ],
"text/plain": [ "text/plain": [
" train_loss train_Accuracy train_Precision \\\n", " train_loss train_Accuracy train_Precision \\\n",
"0 0.187643 [0.92486894, 0.9248898] 0.92487943 \n", "0 0.225108 [0.9270001649856567, 0.8861073851585388] 0.906365 \n",
"1 0.121485 [0.9521358, 0.9490831] 0.9506268 \n", "1 0.152386 [0.9471125602722168, 0.9297876358032227] 0.938380 \n",
"2 0.102038 [0.9534429, 0.96310383] 0.95834136 \n", "2 0.134425 [0.9490371346473694, 0.9407152533531189] 0.944841 \n",
"\n",
" train_Recall \\\n",
"0 [0.9270001649856567, 0.8861073851585388] \n",
"1 [0.9471125602722168, 0.9297876358032227] \n",
"2 [0.9490371346473694, 0.9407152533531189] \n",
"\n", "\n",
" train_Recall train_F1 val_loss \\\n", " train_F1 val_loss \\\n",
"0 [0.92486894, 0.9248898] [0.9244203, 0.92533314] 0.085667 \n", "0 [0.9074797034263611, 0.9052220582962036] 0.103954 \n",
"1 [0.9521358, 0.9490831] [0.95122886, 0.95000976] 0.089378 \n", "1 [0.9384452104568481, 0.9383144974708557] 0.091541 \n",
"2 [0.9534429, 0.96310383] [0.9575638, 0.95909095] 0.115516 \n", "2 [0.9446273446083069, 0.9450528621673584] 0.084717 \n",
"\n", "\n",
" val_Accuracy val_Precision val_Recall \\\n", " val_Accuracy val_Precision \\\n",
"0 [0.96635747, 0.9147287] 0.96590054 [0.96635747, 0.9147287] \n", "0 [0.962550163269043, 0.875] 0.961784 \n",
"1 [0.9613042, 0.9302326] 0.9610292 [0.9613042, 0.9302326] \n", "1 [0.9680188298225403, 0.890625] 0.967341 \n",
"2 [0.9437215, 0.9457364] 0.9437393 [0.9437215, 0.9457364] \n", "2 [0.967949628829956, 0.8828125] 0.967204 \n",
"\n", "\n",
" val_F1 \n", " val_Recall val_F1 \n",
"0 [0.98251045, 0.32196453] \n", "0 [0.962550163269043, 0.875] [0.9803645014762878, 0.28863346576690674] \n",
"1 [0.9799591, 0.2970297] \n", "1 [0.9680188298225403, 0.890625] [0.9832653999328613, 0.3257790505886078] \n",
"2 [0.97080404, 0.22932333] " "2 [0.967949628829956, 0.8828125] [0.9831950664520264, 0.3229461908340454] "
] ]
}, },
"execution_count": 14, "execution_count": 14,
...@@ -1010,7 +1016,7 @@ ...@@ -1010,7 +1016,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"predict: 100%|██████████| 292/292 [00:00<00:00, 306.28it/s]\n" "predict: 100%|██████████| 292/292 [00:00<00:00, 294.65it/s]\n"
] ]
}, },
{ {
...@@ -1019,15 +1025,15 @@ ...@@ -1019,15 +1025,15 @@
"text": [ "text": [
" precision recall f1-score support\n", " precision recall f1-score support\n",
"\n", "\n",
" 0 1.00 0.94 0.97 14446\n", " 0 1.00 0.97 0.99 14446\n",
" 1 0.13 0.95 0.23 130\n", " 1 0.23 0.93 0.36 130\n",
"\n", "\n",
" accuracy 0.94 14576\n", " accuracy 0.97 14576\n",
" macro avg 0.57 0.95 0.60 14576\n", " macro avg 0.61 0.95 0.67 14576\n",
"weighted avg 0.99 0.94 0.96 14576\n", "weighted avg 0.99 0.97 0.98 14576\n",
"\n", "\n",
"Actual predicted values:\n", "Actual predicted values:\n",
"(array([0, 1]), array([13650, 926]))\n" "(array([0, 1]), array([14039, 537]))\n"
] ]
} }
], ],
......
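The classification report above shows macro and weighted precision averages that diverge sharply (0.61 vs 0.99) because the dataset is heavily imbalanced. A minimal sketch of how the two averages are computed, using the per-class precisions and supports taken from the report:

```python
import numpy as np

# per-class precision and support, taken from the classification report above
precision = np.array([1.00, 0.23])
support = np.array([14446, 130])

# macro average: unweighted mean over classes, so the minority class
# drags the score down
macro = precision.mean()

# weighted average: weighted by class support, so the majority class dominates
weighted = (precision * support).sum() / support.sum()

print(f"macro avg: {macro:.2f}, weighted avg: {weighted:.2f}")
```

On imbalanced data like this one, the macro average is usually the more honest summary of minority-class performance.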
This diff is collapsed. This diff is collapsed.
This diff is collapsed. This diff is collapsed.
...@@ -91,6 +91,6 @@ if __name__ == "__main__": ...@@ -91,6 +91,6 @@ if __name__ == "__main__":
X_tab=X_tab, X_tab=X_tab,
target=target, target=target,
n_epochs=2, n_epochs=2,
batch_size=128, batch_size=512,
val_split=0.2, val_split=0.2,
) )
...@@ -17,6 +17,8 @@ using wide and deep models. ...@@ -17,6 +17,8 @@ using wide and deep models.
**Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark) **Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)
**Slack**: if you want to contribute or just want to chat with us, join [slack](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
### Introduction ### Introduction
``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792) ``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792)
...@@ -55,20 +57,20 @@ cd pytorch-widedeep ...@@ -55,20 +57,20 @@ cd pytorch-widedeep
pip install -e . pip install -e .
``` ```
**Important note for Mac users**: at the time of writing (June-2021) the **Important note for Mac users**: at the time of writing the latest `torch`
latest `torch` release is `1.9`. Some past release is `1.9`. Some past [issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206) when running on Mac, present in previous versions, persist on this release
when running on Mac, present in previous versions, persist on this release and and the data-loaders will not run in parallel. In addition, since `python
the data-loaders will not run in parallel. In addition, since `python 3.8`, 3.8`, [the `multiprocessing` library start method changed from `'fork'` to`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
[the `multiprocessing` library start method changed from `'fork'` to
`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will This also affects the data-loaders (for any `torch` version) and they will
not run in parallel. Therefore, for Mac users I recommend using `python 3.6` not run in parallel. Therefore, for Mac users I recommend using `python
or `3.7` and `torch <= 1.6` (with the corresponding, consistent version of 3.6` or `3.7` and `torch <= 1.6` (with the corresponding, consistent
`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this version of `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to
versioning in the `setup.py` file since I expect that all these issues are force this versioning in the `setup.py` file since I expect that all these
fixed in the future. Therefore, after installing `pytorch-widedeep` via pip issues are fixed in the future. Therefore, after installing
or directly from github, downgrade `torch` and `torchvision` manually: `pytorch-widedeep` via pip or directly from github, downgrade `torch` and
`torchvision` manually:
```bash ```bash
pip install pytorch-widedeep pip install pytorch-widedeep
......
...@@ -5,4 +5,5 @@ from pytorch_widedeep.models.wide_deep import WideDeep ...@@ -5,4 +5,5 @@ from pytorch_widedeep.models.wide_deep import WideDeep
from pytorch_widedeep.models.deep_image import DeepImage from pytorch_widedeep.models.deep_image import DeepImage
from pytorch_widedeep.models.tab_resnet import TabResnet from pytorch_widedeep.models.tab_resnet import TabResnet
from pytorch_widedeep.models.tabnet.tab_net import TabNet from pytorch_widedeep.models.tabnet.tab_net import TabNet
from pytorch_widedeep.models.tab_transformer import TabTransformer from pytorch_widedeep.models.transformers.saint import SAINT
from pytorch_widedeep.models.transformers.tab_transformer import TabTransformer
This diff is collapsed. This diff is collapsed.
This diff is collapsed. This diff is collapsed.
This diff is collapsed. This diff is collapsed.
...@@ -24,7 +24,7 @@ class Wide(nn.Module): ...@@ -24,7 +24,7 @@ class Wide(nn.Module):
Attributes Attributes
----------- -----------
wide_linear: :obj:`nn.Module` wide_linear: ``nn.Module``
the linear layer that comprises the wide branch of the model the linear layer that comprises the wide branch of the model
Examples Examples
......
...@@ -109,7 +109,9 @@ def pad_sequences( ...@@ -109,7 +109,9 @@ def pad_sequences(
>>> pad_sequences(seq, maxlen=5, pad_idx=0) >>> pad_sequences(seq, maxlen=5, pad_idx=0)
array([0, 0, 1, 2, 3], dtype=int32) array([0, 0, 1, 2, 3], dtype=int32)
""" """
if len(seq) >= maxlen: if len(seq) == 0:
return np.zeros(maxlen, dtype="int32") + pad_idx
elif len(seq) >= maxlen:
res = np.array(seq[-maxlen:]).astype("int32") res = np.array(seq[-maxlen:]).astype("int32")
return res return res
else: else:
......
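The new branch above makes `pad_sequences` return an all-padding array when given an empty sequence, instead of falling through to the truncation branch. A minimal, self-contained sketch of the whole function; note that the final `else` branch is not shown in the diff, so it is reconstructed here from the docstring example and should be treated as an assumption:

```python
import numpy as np

def pad_sequences(seq, maxlen, pad_idx=1):
    if len(seq) == 0:
        # new in this PR: an empty sequence becomes all padding
        return np.zeros(maxlen, dtype="int32") + pad_idx
    elif len(seq) >= maxlen:
        # sequence too long: keep only the last maxlen elements
        return np.array(seq[-maxlen:]).astype("int32")
    else:
        # left-pad with pad_idx up to maxlen (reconstructed branch,
        # inferred from the docstring example in the diff)
        res = np.zeros(maxlen, dtype="int32") + pad_idx
        res[-len(seq):] = np.array(seq).astype("int32")
        return res

print(pad_sequences([1, 2, 3], maxlen=5, pad_idx=0))  # [0 0 1 2 3]
```

The output matches the docstring example shown in the diff: `array([0, 0, 1, 2, 3], dtype=int32)`.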
__version__ = "1.0.0" __version__ = "1.0.5"
...@@ -16,6 +16,7 @@ from typing import ( ...@@ -16,6 +16,7 @@ from typing import (
) )
from pathlib import PosixPath from pathlib import PosixPath
import torch
from torch import Tensor from torch import Tensor
from torch.nn import Module from torch.nn import Module
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
...@@ -51,6 +52,7 @@ from torch.utils.data.dataloader import DataLoader ...@@ -51,6 +52,7 @@ from torch.utils.data.dataloader import DataLoader
from pytorch_widedeep.models import WideDeep from pytorch_widedeep.models import WideDeep
from pytorch_widedeep.models.tabnet.sparsemax import Entmax15, Sparsemax from pytorch_widedeep.models.tabnet.sparsemax import Entmax15, Sparsemax
from pytorch_widedeep.models.transformers.layers import FullEmbeddingDropout
ListRules = Collection[Callable[[str], str]] ListRules = Collection[Callable[[str], str]]
Tokens = Collection[Collection[str]] Tokens = Collection[Collection[str]]
...@@ -83,3 +85,5 @@ Transforms = Union[ ...@@ -83,3 +85,5 @@ Transforms = Union[
] ]
LRScheduler = _LRScheduler LRScheduler = _LRScheduler
ModelParams = Generator[Tensor, Tensor, Tensor] ModelParams = Generator[Tensor, Tensor, Tensor]
NormLayers = Union[torch.nn.Identity, torch.nn.LayerNorm, torch.nn.BatchNorm1d]
DropoutLayers = Union[torch.nn.Dropout, FullEmbeddingDropout]
...@@ -147,7 +147,7 @@ finetuner = FineTune(loss_fn, MultipleMetrics([Accuracy()]), "binary", False) ...@@ -147,7 +147,7 @@ finetuner = FineTune(loss_fn, MultipleMetrics([Accuracy()]), "binary", False)
# so here we go... # so here we go...
last_linear = list(deeptabular.children())[1] last_linear = list(deeptabular.children())[1]
inverted_mlp_layers = list( inverted_mlp_layers = list(
list(list(deeptabular.named_modules())[9][1].children())[0].children() list(list(deeptabular.named_modules())[10][1].children())[0].children()
)[::-1] )[::-1]
tab_layers = [last_linear] + inverted_mlp_layers tab_layers = [last_linear] + inverted_mlp_layers
text_layers = [c for c in list(deeptext.children())[1:]][::-1] text_layers = [c for c in list(deeptext.children())[1:]][::-1]
......