提交 5f87bd61 编写于 作者: J jrzaurin

Fixed a series of breaking changes related to the metrics returning arrays....

Fixed a series of breaking changes related to the metrics returning arrays. Adapted docs to the new functionalities
上级 ff9ecdc3
...@@ -13,6 +13,9 @@ Here are the 4 callbacks available in ``pytorch-widedepp``: ``History``, ...@@ -13,6 +13,9 @@ Here are the 4 callbacks available in ``pytorch-widedepp``: ``History``,
.. autoclass:: pytorch_widedeep.callbacks.LRShedulerCallback .. autoclass:: pytorch_widedeep.callbacks.LRShedulerCallback
:members: :members:
.. autoclass:: pytorch_widedeep.callbacks.MetricCallback
:members:
.. autoclass:: pytorch_widedeep.callbacks.LRHistory .. autoclass:: pytorch_widedeep.callbacks.LRHistory
:members: :members:
......
Dataloaders
===========
.. note:: This module should contain custom dataloaders that the user might want to
implement. At the moment ``pytorch-widedeep`` offers one custom dataloader,
``DataLoaderImbalanced``.
.. autoclass:: pytorch_widedeep.dataloaders.DataLoaderImbalanced
:members:
:undoc-members:
...@@ -14,3 +14,4 @@ them to address different problems ...@@ -14,3 +14,4 @@ them to address different problems
* `FineTune routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_FineTune_and_WarmUp_Model_Components.ipynb>`__ * `FineTune routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_FineTune_and_WarmUp_Model_Components.ipynb>`__
* `Custom Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/07_Custom_Components.ipynb>`__ * `Custom Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/07_Custom_Components.ipynb>`__
* `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__ * `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__
* `Using Custom DataLoaders and Torchmetrics <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb>`__
...@@ -20,6 +20,7 @@ Documentation ...@@ -20,6 +20,7 @@ Documentation
Model Components <model_components> Model Components <model_components>
Metrics <metrics> Metrics <metrics>
Losses <losses> Losses <losses>
Dataloaders <dataloaders>
Callbacks <callbacks> Callbacks <callbacks>
The Trainer <trainer> The Trainer <trainer>
Examples <examples> Examples <examples>
......
...@@ -40,4 +40,5 @@ Dependencies ...@@ -40,4 +40,5 @@ Dependencies
* torch * torch
* torchvision * torchvision
* einops * einops
* wrapt * wrapt
\ No newline at end of file * torchmetrics
\ No newline at end of file
...@@ -7,6 +7,26 @@ Metrics ...@@ -7,6 +7,26 @@ Metrics
ground truth is expected to be a 1D tensor with the corresponding classes. ground truth is expected to be a 1D tensor with the corresponding classes.
See Examples below See Examples below
We have added the possibility of using the metrics available at the
`torchmetrics <https://torchmetrics.readthedocs.io/en/latest/>`_ library.
Note that this library is still in its early versions and therefore this
option should be used with caution. To use ``torchmetrics`` simply import
them and use them as any of the ``pytorch-widedeep`` metrics described
below.
.. code-block:: python
from torchmetrics import Accuracy, Precision
accuracy = Accuracy(average=None, num_classes=2)
precision = Precision(average='micro', num_classes=2)
trainer = Trainer(model, objective="binary", metrics=[accuracy, precision])
A functioning example for ``pytorch-widedeep`` using ``torchmetrics`` can be
found in the `Examples folder <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples>`_.
.. autoclass:: pytorch_widedeep.metrics.Accuracy .. autoclass:: pytorch_widedeep.metrics.Accuracy
:members: :members:
:undoc-members: :undoc-members:
......
...@@ -16,4 +16,5 @@ tqdm ...@@ -16,4 +16,5 @@ tqdm
torch torch
torchvision torchvision
einops einops
wrapt wrapt
\ No newline at end of file torchmetrics
\ No newline at end of file
Training wide and deep models for tabular data Training wide and deep models for tabular data
============================================== ===============================================
`...` or just deep learning models for tabular data. `...` or just deep learning models for tabular data.
......
...@@ -32,7 +32,16 @@ ...@@ -32,7 +32,16 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/javier/.pyenv/versions/3.7.7/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
"import pandas as pd\n", "import pandas as pd\n",
...@@ -66,7 +75,7 @@ ...@@ -66,7 +75,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -625,20 +634,20 @@ ...@@ -625,20 +634,20 @@
"4 0.68 -0.59 2.0 -36.0 -6.9 2.02 0.14 -0.23 " "4 0.68 -0.59 2.0 -36.0 -6.9 2.02 0.14 -0.23 "
] ]
}, },
"execution_count": 2, "execution_count": 3,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"header_list = ['EXAMPLE_ID', 'BLOCK_ID', 'target'] + [str(i) for i in range(4,78)]\n", "header_list = ['EXAMPLE_ID', 'BLOCK_ID', 'target'] + [str(i) for i in range(4,78)]\n",
"df = pd.read_csv('data/bio_train.dat', sep='\\t', names=header_list)\n", "df = pd.read_csv('data/kddcup04/bio_train.dat', sep='\\t', names=header_list)\n",
"df.head()" "df.head()"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -649,7 +658,7 @@ ...@@ -649,7 +658,7 @@
"Name: target, dtype: int64" "Name: target, dtype: int64"
] ]
}, },
"execution_count": 3, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -661,7 +670,7 @@ ...@@ -661,7 +670,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -671,7 +680,7 @@ ...@@ -671,7 +680,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -688,7 +697,7 @@ ...@@ -688,7 +697,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -697,7 +706,7 @@ ...@@ -697,7 +706,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -723,7 +732,7 @@ ...@@ -723,7 +732,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -734,7 +743,7 @@ ...@@ -734,7 +743,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -778,7 +787,7 @@ ...@@ -778,7 +787,7 @@
")" ")"
] ]
}, },
"execution_count": 9, "execution_count": 10,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -793,7 +802,7 @@ ...@@ -793,7 +802,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -810,16 +819,16 @@ ...@@ -810,16 +819,16 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Metrics from pytorch-widedeep\n", "# # Metrics from pytorch-widedeep\n",
"accuracy = Accuracy(top_k=2)\n", "# accuracy = Accuracy(top_k=2)\n",
"precision = Precision(average=False)\n", "# precision = Precision(average=False)\n",
"recall = Recall(average=True)\n", "# recall = Recall(average=True)\n",
"f1 = F1Score(average=False)" "# f1 = F1Score(average=False)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -839,26 +848,31 @@ ...@@ -839,26 +848,31 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 208/208 [00:01<00:00, 163.29it/s, loss=0.126, metrics={'Accuracy': [0.9522, 0.9451], 'Precision': 0.9486, 'Recall': [0.9522, 0.9451], 'F1': [0.9482, 0.949]}] \n", "epoch 1: 0%| | 0/208 [00:00<?, ?it/s]/Users/javier/.pyenv/versions/3.7.7/envs/widedeep37/lib/python3.7/site-packages/torchmetrics/functional/classification/accuracy.py:90: UserWarning: This overload of nonzero is deprecated:\n",
"valid: 100%|██████████| 292/292 [00:01<00:00, 161.79it/s, loss=0.125, metrics={'Accuracy': [0.9466, 0.9302], 'Precision': 0.9464, 'Recall': [0.9466, 0.9302], 'F1': [0.9722, 0.2351]}]\n", "\tnonzero(Tensor input, *, Tensor out)\n",
"epoch 2: 100%|██████████| 208/208 [00:01<00:00, 165.78it/s, loss=0.129, metrics={'Accuracy': [0.9579, 0.9431], 'Precision': 0.9504, 'Recall': [0.9579, 0.9431], 'F1': [0.9503, 0.9506]}]\n", "Consider using one of the following signatures instead:\n",
"valid: 100%|██████████| 292/292 [00:01<00:00, 163.66it/s, loss=0.127, metrics={'Accuracy': [0.946, 0.9302], 'Precision': 0.9459, 'Recall': [0.946, 0.9302], 'F1': [0.9719, 0.2332]}] \n", "\tnonzero(Tensor input, *, bool as_tuple) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:766.)\n",
"epoch 3: 100%|██████████| 208/208 [00:01<00:00, 162.24it/s, loss=0.124, metrics={'Accuracy': [0.9524, 0.9462], 'Precision': 0.9493, 'Recall': [0.9524, 0.9462], 'F1': [0.9491, 0.9495]}]\n", " meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()\n",
"valid: 100%|██████████| 292/292 [00:01<00:00, 161.53it/s, loss=0.127, metrics={'Accuracy': [0.9457, 0.9302], 'Precision': 0.9456, 'Recall': [0.9457, 0.9302], 'F1': [0.9718, 0.2323]}]\n" "epoch 1: 100%|██████████| 208/208 [00:02<00:00, 78.94it/s, loss=0.188, metrics={'Accuracy': [0.9249, 0.9249], 'Precision': 0.9249, 'Recall': [0.9249, 0.9249], 'F1': [0.9244, 0.9253]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 107.92it/s, loss=0.0857, metrics={'Accuracy': [0.9664, 0.9147], 'Precision': 0.9659, 'Recall': [0.9664, 0.9147], 'F1': [0.9825, 0.322]}] \n",
"epoch 2: 100%|██████████| 208/208 [00:02<00:00, 88.97it/s, loss=0.121, metrics={'Accuracy': [0.9521, 0.9491], 'Precision': 0.9506, 'Recall': [0.9521, 0.9491], 'F1': [0.9512, 0.95]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 109.12it/s, loss=0.0894, metrics={'Accuracy': [0.9613, 0.9302], 'Precision': 0.961, 'Recall': [0.9613, 0.9302], 'F1': [0.98, 0.297]}] \n",
"epoch 3: 100%|██████████| 208/208 [00:02<00:00, 86.76it/s, loss=0.102, metrics={'Accuracy': [0.9534, 0.9631], 'Precision': 0.9583, 'Recall': [0.9534, 0.9631], 'F1': [0.9576, 0.9591]}]\n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 106.86it/s, loss=0.116, metrics={'Accuracy': [0.9437, 0.9457], 'Precision': 0.9437, 'Recall': [0.9437, 0.9457], 'F1': [0.9708, 0.2293]}]\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Training time[s]: 0:00:09\n" "Training time[s]: 0:00:16\n"
] ]
} }
], ],
...@@ -876,7 +890,7 @@ ...@@ -876,7 +890,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -915,70 +929,70 @@ ...@@ -915,70 +929,70 @@
" <tbody>\n", " <tbody>\n",
" <tr>\n", " <tr>\n",
" <th>0</th>\n", " <th>0</th>\n",
" <td>0.358859</td>\n", " <td>0.187643</td>\n",
" <td>[0.78240305, 0.86529005]</td>\n", " <td>[0.92486894, 0.9248898]</td>\n",
" <td>0.8230473</td>\n", " <td>0.92487943</td>\n",
" <td>[0.78240305, 0.86529005]</td>\n", " <td>[0.92486894, 0.9248898]</td>\n",
" <td>[0.8184067, 0.8274565]</td>\n", " <td>[0.9244203, 0.92533314]</td>\n",
" <td>0.145092</td>\n", " <td>0.085667</td>\n",
" <td>[0.97480273, 0.85271317]</td>\n", " <td>[0.96635747, 0.9147287]</td>\n",
" <td>0.9737221</td>\n", " <td>0.96590054</td>\n",
" <td>[0.97480273, 0.85271317]</td>\n", " <td>[0.96635747, 0.9147287]</td>\n",
" <td>[0.9865836, 0.36484244]</td>\n", " <td>[0.98251045, 0.32196453]</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <th>1</th>\n", " <th>1</th>\n",
" <td>0.201815</td>\n", " <td>0.121485</td>\n",
" <td>[0.94106466, 0.90215266]</td>\n", " <td>[0.9521358, 0.9490831]</td>\n",
" <td>0.9218901</td>\n", " <td>0.9506268</td>\n",
" <td>[0.94106466, 0.90215266]</td>\n", " <td>[0.9521358, 0.9490831]</td>\n",
" <td>[0.92436975, 0.9192423]</td>\n", " <td>[0.95122886, 0.95000976]</td>\n",
" <td>0.276596</td>\n", " <td>0.089378</td>\n",
" <td>[0.878167, 0.9612403]</td>\n", " <td>[0.9613042, 0.9302326]</td>\n",
" <td>0.87890226</td>\n", " <td>0.9610292</td>\n",
" <td>[0.878167, 0.9612403]</td>\n", " <td>[0.9613042, 0.9302326]</td>\n",
" <td>[0.9349596, 0.12319921]</td>\n", " <td>[0.9799591, 0.2970297]</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <th>2</th>\n", " <th>2</th>\n",
" <td>0.154379</td>\n", " <td>0.102038</td>\n",
" <td>[0.936721, 0.94180405]</td>\n", " <td>[0.9534429, 0.96310383]</td>\n",
" <td>0.93924785</td>\n", " <td>0.95834136</td>\n",
" <td>[0.936721, 0.94180405]</td>\n", " <td>[0.9534429, 0.96310383]</td>\n",
" <td>[0.9394231, 0.93907154]</td>\n", " <td>[0.9575638, 0.95909095]</td>\n",
" <td>0.100671</td>\n", " <td>0.115516</td>\n",
" <td>[0.9632424, 0.8992248]</td>\n", " <td>[0.9437215, 0.9457364]</td>\n",
" <td>0.9626758</td>\n", " <td>0.9437393</td>\n",
" <td>[0.9632424, 0.8992248]</td>\n", " <td>[0.9437215, 0.9457364]</td>\n",
" <td>[0.9808275, 0.2989691]</td>\n", " <td>[0.97080404, 0.22932333]</td>\n",
" </tr>\n", " </tr>\n",
" </tbody>\n", " </tbody>\n",
"</table>\n", "</table>\n",
"</div>" "</div>"
], ],
"text/plain": [ "text/plain": [
" train_loss train_Accuracy train_Precision \\\n", " train_loss train_Accuracy train_Precision \\\n",
"0 0.358859 [0.78240305, 0.86529005] 0.8230473 \n", "0 0.187643 [0.92486894, 0.9248898] 0.92487943 \n",
"1 0.201815 [0.94106466, 0.90215266] 0.9218901 \n", "1 0.121485 [0.9521358, 0.9490831] 0.9506268 \n",
"2 0.154379 [0.936721, 0.94180405] 0.93924785 \n", "2 0.102038 [0.9534429, 0.96310383] 0.95834136 \n",
"\n", "\n",
" train_Recall train_F1 val_loss \\\n", " train_Recall train_F1 val_loss \\\n",
"0 [0.78240305, 0.86529005] [0.8184067, 0.8274565] 0.145092 \n", "0 [0.92486894, 0.9248898] [0.9244203, 0.92533314] 0.085667 \n",
"1 [0.94106466, 0.90215266] [0.92436975, 0.9192423] 0.276596 \n", "1 [0.9521358, 0.9490831] [0.95122886, 0.95000976] 0.089378 \n",
"2 [0.936721, 0.94180405] [0.9394231, 0.93907154] 0.100671 \n", "2 [0.9534429, 0.96310383] [0.9575638, 0.95909095] 0.115516 \n",
"\n", "\n",
" val_Accuracy val_Precision val_Recall \\\n", " val_Accuracy val_Precision val_Recall \\\n",
"0 [0.97480273, 0.85271317] 0.9737221 [0.97480273, 0.85271317] \n", "0 [0.96635747, 0.9147287] 0.96590054 [0.96635747, 0.9147287] \n",
"1 [0.878167, 0.9612403] 0.87890226 [0.878167, 0.9612403] \n", "1 [0.9613042, 0.9302326] 0.9610292 [0.9613042, 0.9302326] \n",
"2 [0.9632424, 0.8992248] 0.9626758 [0.9632424, 0.8992248] \n", "2 [0.9437215, 0.9457364] 0.9437393 [0.9437215, 0.9457364] \n",
"\n", "\n",
" val_F1 \n", " val_F1 \n",
"0 [0.9865836, 0.36484244] \n", "0 [0.98251045, 0.32196453] \n",
"1 [0.9349596, 0.12319921] \n", "1 [0.9799591, 0.2970297] \n",
"2 [0.9808275, 0.2989691] " "2 [0.97080404, 0.22932333] "
] ]
}, },
"execution_count": 21, "execution_count": 14,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -989,14 +1003,14 @@ ...@@ -989,14 +1003,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"predict: 100%|██████████| 292/292 [00:00<00:00, 368.56it/s]\n" "predict: 100%|██████████| 292/292 [00:00<00:00, 306.28it/s]\n"
] ]
}, },
{ {
...@@ -1005,15 +1019,15 @@ ...@@ -1005,15 +1019,15 @@
"text": [ "text": [
" precision recall f1-score support\n", " precision recall f1-score support\n",
"\n", "\n",
" 0 1.00 0.96 0.98 14446\n", " 0 1.00 0.94 0.97 14446\n",
" 1 0.18 0.92 0.30 130\n", " 1 0.13 0.95 0.23 130\n",
"\n", "\n",
" accuracy 0.96 14576\n", " accuracy 0.94 14576\n",
" macro avg 0.59 0.94 0.64 14576\n", " macro avg 0.57 0.95 0.60 14576\n",
"weighted avg 0.99 0.96 0.97 14576\n", "weighted avg 0.99 0.94 0.96 14576\n",
"\n", "\n",
"Actual predicted values:\n", "Actual predicted values:\n",
"(array([0, 1]), array([13910, 666]))\n" "(array([0, 1]), array([13650, 926]))\n"
] ]
} }
], ],
......
...@@ -11,6 +11,8 @@ from pytorch_widedeep.models import ( # noqa: F401 ...@@ -11,6 +11,8 @@ from pytorch_widedeep.models import ( # noqa: F401
TabResnet, TabResnet,
) )
from pytorch_widedeep.metrics import Accuracy, Precision from pytorch_widedeep.metrics import Accuracy, Precision
# from torchmetrics import Accuracy as accuracy_score
from pytorch_widedeep.callbacks import ( from pytorch_widedeep.callbacks import (
LRHistory, LRHistory,
EarlyStopping, EarlyStopping,
...@@ -94,6 +96,7 @@ if __name__ == "__main__": ...@@ -94,6 +96,7 @@ if __name__ == "__main__":
schedulers = {"wide": wide_sch, "deeptabular": deep_sch} schedulers = {"wide": wide_sch, "deeptabular": deep_sch}
initializers = {"wide": KaimingNormal, "deeptabular": XavierNormal} initializers = {"wide": KaimingNormal, "deeptabular": XavierNormal}
callbacks = [early_stopping, model_checkpoint, LRHistory(n_epochs=10)] callbacks = [early_stopping, model_checkpoint, LRHistory(n_epochs=10)]
# metrics = [Accuracy, accuracy_score(num_classes=2), Precision]
metrics = [Accuracy, Precision] metrics = [Accuracy, Precision]
trainer = Trainer( trainer = Trainer(
......
...@@ -157,6 +157,8 @@ class History(Callback): ...@@ -157,6 +157,8 @@ class History(Callback):
): ):
logs = logs or {} logs = logs or {}
for k, v in logs.items(): for k, v in logs.items():
if isinstance(v, np.ndarray):
v = v.tolist()
self.trainer.history.setdefault(k, []).append(v) self.trainer.history.setdefault(k, []).append(v)
...@@ -216,6 +218,12 @@ class LRShedulerCallback(Callback): ...@@ -216,6 +218,12 @@ class LRShedulerCallback(Callback):
class MetricCallback(Callback): class MetricCallback(Callback):
r"""Callback that resets the metrics (if any metric is used)
This callback runs by default within :obj:`Trainer`, therefore, should not
be passed to the :obj:`Trainer`. Is included here just for completion.
"""
def __init__(self, container: MultipleMetrics): def __init__(self, container: MultipleMetrics):
self.container = container self.container = container
......
...@@ -34,23 +34,18 @@ class DataLoaderDefault(DataLoader): ...@@ -34,23 +34,18 @@ class DataLoaderDefault(DataLoader):
class DataLoaderImbalanced(DataLoader): class DataLoaderImbalanced(DataLoader):
r"""Helper function to load and shuffle tensors into models in batches with r"""Class to load and shuffle batches with adjusted weights for imbalanced
adjusted weights to "fight" against imbalance of the classes. If the datasets. If the classes do not begin from 0 remapping is necessary. See
classes do not begin from 0 remapping is necessary, see: `here <https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab>`_
https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab
Parameters Parameters
---------- ----------
dataset ``WideDeepDataset``: dataset: ``WideDeepDataset``
dataset containing target classes in dataset.Y see ``pytorch_widedeep.training._wd_dataset``
batch_size: int batch_size: int
size of batch size of batch
num_workers: int num_workers: int
number of workers number of workers
Returns:
--------
PyTorch ``DataLoader`` object
""" """
def __init__( def __init__(
......
...@@ -67,13 +67,13 @@ class Accuracy(Metric): ...@@ -67,13 +67,13 @@ class Accuracy(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> acc(y_pred, y_true) >>> acc(y_pred, y_true)
0.5 array(0.5)
>>> >>>
>>> acc = Accuracy(top_k=2) >>> acc = Accuracy(top_k=2)
>>> y_true = torch.tensor([0, 1, 2]) >>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> acc(y_pred, y_true) >>> acc(y_pred, y_true)
0.6666666666666666 array(0.66666667)
""" """
def __init__(self, top_k: int = 1): def __init__(self, top_k: int = 1):
...@@ -126,13 +126,13 @@ class Precision(Metric): ...@@ -126,13 +126,13 @@ class Precision(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> prec(y_pred, y_true) >>> prec(y_pred, y_true)
0.5 array(0.5)
>>> >>>
>>> prec = Precision(average=True) >>> prec = Precision(average=True)
>>> y_true = torch.tensor([0, 1, 2]) >>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> prec(y_pred, y_true) >>> prec(y_pred, y_true)
0.3333333432674408 array(0.33333334)
""" """
def __init__(self, average: bool = True): def __init__(self, average: bool = True):
...@@ -192,13 +192,13 @@ class Recall(Metric): ...@@ -192,13 +192,13 @@ class Recall(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> rec(y_pred, y_true) >>> rec(y_pred, y_true)
0.5 array(0.5)
>>> >>>
>>> rec = Recall(average=True) >>> rec = Recall(average=True)
>>> y_true = torch.tensor([0, 1, 2]) >>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> rec(y_pred, y_true) >>> rec(y_pred, y_true)
0.3333333432674408 array(0.33333334)
""" """
def __init__(self, average: bool = True): def __init__(self, average: bool = True):
...@@ -262,13 +262,13 @@ class FBetaScore(Metric): ...@@ -262,13 +262,13 @@ class FBetaScore(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> fbeta(y_pred, y_true) >>> fbeta(y_pred, y_true)
0.5 array(0.5)
>>> >>>
>>> fbeta = FBetaScore(beta=2) >>> fbeta = FBetaScore(beta=2)
>>> y_true = torch.tensor([0, 1, 2]) >>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> fbeta(y_pred, y_true) >>> fbeta(y_pred, y_true)
0.3333333432674408 array(0.33333334)
""" """
def __init__(self, beta: int, average: bool = True): def __init__(self, beta: int, average: bool = True):
...@@ -321,13 +321,13 @@ class F1Score(Metric): ...@@ -321,13 +321,13 @@ class F1Score(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> f1(y_pred, y_true) >>> f1(y_pred, y_true)
0.5 array(0.5)
>>> >>>
>>> f1 = F1Score() >>> f1 = F1Score()
>>> y_true = torch.tensor([0, 1, 2]) >>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> f1(y_pred, y_true) >>> f1(y_pred, y_true)
0.3333333432674408 array(0.33333334)
""" """
def __init__(self, average: bool = True): def __init__(self, average: bool = True):
...@@ -367,7 +367,7 @@ class R2Score(Metric): ...@@ -367,7 +367,7 @@ class R2Score(Metric):
>>> y_true = torch.tensor([3, -0.5, 2, 7]).view(-1, 1) >>> y_true = torch.tensor([3, -0.5, 2, 7]).view(-1, 1)
>>> y_pred = torch.tensor([2.5, 0.0, 2, 8]).view(-1, 1) >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]).view(-1, 1)
>>> r2(y_pred, y_true) >>> r2(y_pred, y_true)
0.9486081370449679 array(0.94860814)
""" """
def __init__(self): def __init__(self):
......
...@@ -85,7 +85,7 @@ class Trainer: ...@@ -85,7 +85,7 @@ class Trainer:
function. See for example function. See for example
:class:`pytorch_widedeep.losses.FocalLoss` for the required :class:`pytorch_widedeep.losses.FocalLoss` for the required
structure of the object or the `Examples structure of the object or the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_ <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo. folder in the repo.
.. note:: If ``custom_loss_function`` is not None, ``objective`` must be .. note:: If ``custom_loss_function`` is not None, ``objective`` must be
...@@ -128,7 +128,7 @@ class Trainer: ...@@ -128,7 +128,7 @@ class Trainer:
callbacks are used by default. This can also be a custom callback as callbacks are used by default. This can also be a custom callback as
long as the object of type ``Callback``. See long as the object of type ``Callback``. See
:obj:`pytorch_widedeep.callbacks.Callback` or the `Examples :obj:`pytorch_widedeep.callbacks.Callback` or the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_ <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo folder in the repo
metrics: List, optional, default=None metrics: List, optional, default=None
- List of objects of type :obj:`Metric`. Metrics available are: - List of objects of type :obj:`Metric`. Metrics available are:
...@@ -136,7 +136,7 @@ class Trainer: ...@@ -136,7 +136,7 @@ class Trainer:
``F1Score`` and ``R2Score``. This can also be a custom metric as ``F1Score`` and ``R2Score``. This can also be a custom metric as
long as it is an object of type :obj:`Metric`. See long as it is an object of type :obj:`Metric`. See
:obj:`pytorch_widedeep.metrics.Metric` or the `Examples :obj:`pytorch_widedeep.metrics.Metric` or the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_ <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo folder in the repo
- List of objects of type :obj:`torchmetrics.Metric`. This can be any - List of objects of type :obj:`torchmetrics.Metric`. This can be any
metric from torchmetrics library `Examples metric from torchmetrics library `Examples
...@@ -379,11 +379,10 @@ class Trainer: ...@@ -379,11 +379,10 @@ class Trainer:
epochs validation frequency epochs validation frequency
batch_size: int, default=32 batch_size: int, default=32
batch size batch size
custom_dataloader: ``torch.utils.data.DataLoader``, Optional, default=None custom_dataloader: ``DataLoader``, Optional, default=None
object of class ``torch.utils.data.DataLoader``. Available object of class ``torch.utils.data.DataLoader``. Available
predefined dataloaders are ``DataLoaderDefault``, ```` in predefined dataloaders are in ``pytorch-widedeep.dataloaders``.If
``pytorch_widedeep.metrics``. If None, a ``DataLoaderDefault`` ``None``, a standard torch ``DataLoader`` is used.
is set.
finetune: bool, default=False finetune: bool, default=False
param alias: ``warmup`` param alias: ``warmup``
...@@ -415,7 +414,7 @@ class Trainer: ...@@ -415,7 +414,7 @@ class Trainer:
For details on how these routines work, please see the Examples For details on how these routines work, please see the Examples
section in this documentation and the `Examples section in this documentation and the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_ <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo. folder in the repo.
finetune_epochs: int, default=4 finetune_epochs: int, default=4
param alias: ``warmup_epochs`` param alias: ``warmup_epochs``
...@@ -493,7 +492,7 @@ class Trainer: ...@@ -493,7 +492,7 @@ class Trainer:
-------- --------
For a series of comprehensive examples please, see the `Examples For a series of comprehensive examples please, see the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_ <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo folder in the repo
For completion, here we include some `"fabricated"` examples, i.e. For completion, here we include some `"fabricated"` examples, i.e.
...@@ -772,7 +771,7 @@ class Trainer: ...@@ -772,7 +771,7 @@ class Trainer:
-------- --------
For a series of comprehensive examples please, see the `Examples For a series of comprehensive examples please, see the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_ <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo folder in the repo
For completion, here we include a `"fabricated"` example, i.e. For completion, here we include a `"fabricated"` example, i.e.
...@@ -859,7 +858,7 @@ class Trainer: ...@@ -859,7 +858,7 @@ class Trainer:
save_state_dict: bool = False, save_state_dict: bool = False,
model_filename: str = "wd_model.pt", model_filename: str = "wd_model.pt",
): ):
"""Saves the model, training and evaluation history, and the r"""Saves the model, training and evaluation history, and the
``feature_importance`` attribute (if the ``deeptabular`` component is a ``feature_importance`` attribute (if the ``deeptabular`` component is a
Tabnet model) to disk Tabnet model) to disk
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册