提交 5f87bd61 编写于 作者: J jrzaurin

Fixed a series of breaking changes related to the metrics returning arrays....

Fixed a series of breaking changes related to the metrics returning arrays. Adapted docs to the new functionalities
上级 ff9ecdc3
......@@ -13,6 +13,9 @@ Here are the 4 callbacks available in ``pytorch-widedepp``: ``History``,
.. autoclass:: pytorch_widedeep.callbacks.LRShedulerCallback
:members:
.. autoclass:: pytorch_widedeep.callbacks.MetricCallback
:members:
.. autoclass:: pytorch_widedeep.callbacks.LRHistory
:members:
......
Dataloaders
===========
.. note:: This module should contain custom dataloaders that the user might want to
implement. At the moment ``pytorch-widedeep`` offers one custom dataloader,
``DataLoaderImbalanced``.
.. autoclass:: pytorch_widedeep.dataloaders.DataLoaderImbalanced
:members:
:undoc-members:
......@@ -14,3 +14,4 @@ them to address different problems
* `FineTune routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_FineTune_and_WarmUp_Model_Components.ipynb>`__
* `Custom Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/07_Custom_Components.ipynb>`__
* `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__
* `Using Custom DataLoaders and Torchmetrics <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb>`__
......@@ -20,6 +20,7 @@ Documentation
Model Components <model_components>
Metrics <metrics>
Losses <losses>
Dataloaders <dataloaders>
Callbacks <callbacks>
The Trainer <trainer>
Examples <examples>
......
......@@ -40,4 +40,5 @@ Dependencies
* torch
* torchvision
* einops
* wrapt
\ No newline at end of file
* wrapt
* torchmetrics
\ No newline at end of file
......@@ -7,6 +7,26 @@ Metrics
ground truth is expected to be a 1D tensor with the corresponding classes.
See Examples below
We have added the possibility of using the metrics available at the
`torchmetrics <https://torchmetrics.readthedocs.io/en/latest/>`_ library.
Note that this library is still in its early versions and therefore this
option should be used with caution. To use ``torchmetrics`` simply import
them and use them as any of the ``pytorch-widedeep`` metrics described
below.
.. code-block:: python
from torchmetrics import Accuracy, Precision
accuracy = Accuracy(average=None, num_classes=2)
precision = Precision(average='micro', num_classes=2)
trainer = Trainer(model, objective="binary", metrics=[accuracy, precision])
A functioning example for ``pytorch-widedeep`` using ``torchmetrics`` can be
found in the `Examples folder <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples>`_.
.. autoclass:: pytorch_widedeep.metrics.Accuracy
:members:
:undoc-members:
......
......@@ -16,4 +16,5 @@ tqdm
torch
torchvision
einops
wrapt
\ No newline at end of file
wrapt
torchmetrics
\ No newline at end of file
Training wide and deep models for tabular data
==============================================
===============================================
`...` or just deep learning models for tabular data.
......
......@@ -32,7 +32,16 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/javier/.pyenv/versions/3.7.7/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
......@@ -66,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
......@@ -625,20 +634,20 @@
"4 0.68 -0.59 2.0 -36.0 -6.9 2.02 0.14 -0.23 "
]
},
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"header_list = ['EXAMPLE_ID', 'BLOCK_ID', 'target'] + [str(i) for i in range(4,78)]\n",
"df = pd.read_csv('data/bio_train.dat', sep='\\t', names=header_list)\n",
"df = pd.read_csv('data/kddcup04/bio_train.dat', sep='\\t', names=header_list)\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
......@@ -649,7 +658,7 @@
"Name: target, dtype: int64"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
......@@ -661,7 +670,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
......@@ -671,7 +680,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
......@@ -688,7 +697,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
......@@ -697,7 +706,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
......@@ -723,7 +732,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
......@@ -734,7 +743,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
......@@ -778,7 +787,7 @@
")"
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
......@@ -793,7 +802,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
......@@ -810,16 +819,16 @@
"metadata": {},
"outputs": [],
"source": [
"# Metrics from pytorch-widedeep\n",
"accuracy = Accuracy(top_k=2)\n",
"precision = Precision(average=False)\n",
"recall = Recall(average=True)\n",
"f1 = F1Score(average=False)"
"# # Metrics from pytorch-widedeep\n",
"# accuracy = Accuracy(top_k=2)\n",
"# precision = Precision(average=False)\n",
"# recall = Recall(average=True)\n",
"# f1 = F1Score(average=False)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
......@@ -839,26 +848,31 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 208/208 [00:01<00:00, 163.29it/s, loss=0.126, metrics={'Accuracy': [0.9522, 0.9451], 'Precision': 0.9486, 'Recall': [0.9522, 0.9451], 'F1': [0.9482, 0.949]}] \n",
"valid: 100%|██████████| 292/292 [00:01<00:00, 161.79it/s, loss=0.125, metrics={'Accuracy': [0.9466, 0.9302], 'Precision': 0.9464, 'Recall': [0.9466, 0.9302], 'F1': [0.9722, 0.2351]}]\n",
"epoch 2: 100%|██████████| 208/208 [00:01<00:00, 165.78it/s, loss=0.129, metrics={'Accuracy': [0.9579, 0.9431], 'Precision': 0.9504, 'Recall': [0.9579, 0.9431], 'F1': [0.9503, 0.9506]}]\n",
"valid: 100%|██████████| 292/292 [00:01<00:00, 163.66it/s, loss=0.127, metrics={'Accuracy': [0.946, 0.9302], 'Precision': 0.9459, 'Recall': [0.946, 0.9302], 'F1': [0.9719, 0.2332]}] \n",
"epoch 3: 100%|██████████| 208/208 [00:01<00:00, 162.24it/s, loss=0.124, metrics={'Accuracy': [0.9524, 0.9462], 'Precision': 0.9493, 'Recall': [0.9524, 0.9462], 'F1': [0.9491, 0.9495]}]\n",
"valid: 100%|██████████| 292/292 [00:01<00:00, 161.53it/s, loss=0.127, metrics={'Accuracy': [0.9457, 0.9302], 'Precision': 0.9456, 'Recall': [0.9457, 0.9302], 'F1': [0.9718, 0.2323]}]\n"
"epoch 1: 0%| | 0/208 [00:00<?, ?it/s]/Users/javier/.pyenv/versions/3.7.7/envs/widedeep37/lib/python3.7/site-packages/torchmetrics/functional/classification/accuracy.py:90: UserWarning: This overload of nonzero is deprecated:\n",
"\tnonzero(Tensor input, *, Tensor out)\n",
"Consider using one of the following signatures instead:\n",
"\tnonzero(Tensor input, *, bool as_tuple) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:766.)\n",
" meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()\n",
"epoch 1: 100%|██████████| 208/208 [00:02<00:00, 78.94it/s, loss=0.188, metrics={'Accuracy': [0.9249, 0.9249], 'Precision': 0.9249, 'Recall': [0.9249, 0.9249], 'F1': [0.9244, 0.9253]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 107.92it/s, loss=0.0857, metrics={'Accuracy': [0.9664, 0.9147], 'Precision': 0.9659, 'Recall': [0.9664, 0.9147], 'F1': [0.9825, 0.322]}] \n",
"epoch 2: 100%|██████████| 208/208 [00:02<00:00, 88.97it/s, loss=0.121, metrics={'Accuracy': [0.9521, 0.9491], 'Precision': 0.9506, 'Recall': [0.9521, 0.9491], 'F1': [0.9512, 0.95]}] \n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 109.12it/s, loss=0.0894, metrics={'Accuracy': [0.9613, 0.9302], 'Precision': 0.961, 'Recall': [0.9613, 0.9302], 'F1': [0.98, 0.297]}] \n",
"epoch 3: 100%|██████████| 208/208 [00:02<00:00, 86.76it/s, loss=0.102, metrics={'Accuracy': [0.9534, 0.9631], 'Precision': 0.9583, 'Recall': [0.9534, 0.9631], 'F1': [0.9576, 0.9591]}]\n",
"valid: 100%|██████████| 292/292 [00:02<00:00, 106.86it/s, loss=0.116, metrics={'Accuracy': [0.9437, 0.9457], 'Precision': 0.9437, 'Recall': [0.9437, 0.9457], 'F1': [0.9708, 0.2293]}]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training time[s]: 0:00:09\n"
"Training time[s]: 0:00:16\n"
]
}
],
......@@ -876,7 +890,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 14,
"metadata": {},
"outputs": [
{
......@@ -915,70 +929,70 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.358859</td>\n",
" <td>[0.78240305, 0.86529005]</td>\n",
" <td>0.8230473</td>\n",
" <td>[0.78240305, 0.86529005]</td>\n",
" <td>[0.8184067, 0.8274565]</td>\n",
" <td>0.145092</td>\n",
" <td>[0.97480273, 0.85271317]</td>\n",
" <td>0.9737221</td>\n",
" <td>[0.97480273, 0.85271317]</td>\n",
" <td>[0.9865836, 0.36484244]</td>\n",
" <td>0.187643</td>\n",
" <td>[0.92486894, 0.9248898]</td>\n",
" <td>0.92487943</td>\n",
" <td>[0.92486894, 0.9248898]</td>\n",
" <td>[0.9244203, 0.92533314]</td>\n",
" <td>0.085667</td>\n",
" <td>[0.96635747, 0.9147287]</td>\n",
" <td>0.96590054</td>\n",
" <td>[0.96635747, 0.9147287]</td>\n",
" <td>[0.98251045, 0.32196453]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.201815</td>\n",
" <td>[0.94106466, 0.90215266]</td>\n",
" <td>0.9218901</td>\n",
" <td>[0.94106466, 0.90215266]</td>\n",
" <td>[0.92436975, 0.9192423]</td>\n",
" <td>0.276596</td>\n",
" <td>[0.878167, 0.9612403]</td>\n",
" <td>0.87890226</td>\n",
" <td>[0.878167, 0.9612403]</td>\n",
" <td>[0.9349596, 0.12319921]</td>\n",
" <td>0.121485</td>\n",
" <td>[0.9521358, 0.9490831]</td>\n",
" <td>0.9506268</td>\n",
" <td>[0.9521358, 0.9490831]</td>\n",
" <td>[0.95122886, 0.95000976]</td>\n",
" <td>0.089378</td>\n",
" <td>[0.9613042, 0.9302326]</td>\n",
" <td>0.9610292</td>\n",
" <td>[0.9613042, 0.9302326]</td>\n",
" <td>[0.9799591, 0.2970297]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.154379</td>\n",
" <td>[0.936721, 0.94180405]</td>\n",
" <td>0.93924785</td>\n",
" <td>[0.936721, 0.94180405]</td>\n",
" <td>[0.9394231, 0.93907154]</td>\n",
" <td>0.100671</td>\n",
" <td>[0.9632424, 0.8992248]</td>\n",
" <td>0.9626758</td>\n",
" <td>[0.9632424, 0.8992248]</td>\n",
" <td>[0.9808275, 0.2989691]</td>\n",
" <td>0.102038</td>\n",
" <td>[0.9534429, 0.96310383]</td>\n",
" <td>0.95834136</td>\n",
" <td>[0.9534429, 0.96310383]</td>\n",
" <td>[0.9575638, 0.95909095]</td>\n",
" <td>0.115516</td>\n",
" <td>[0.9437215, 0.9457364]</td>\n",
" <td>0.9437393</td>\n",
" <td>[0.9437215, 0.9457364]</td>\n",
" <td>[0.97080404, 0.22932333]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" train_loss train_Accuracy train_Precision \\\n",
"0 0.358859 [0.78240305, 0.86529005] 0.8230473 \n",
"1 0.201815 [0.94106466, 0.90215266] 0.9218901 \n",
"2 0.154379 [0.936721, 0.94180405] 0.93924785 \n",
" train_loss train_Accuracy train_Precision \\\n",
"0 0.187643 [0.92486894, 0.9248898] 0.92487943 \n",
"1 0.121485 [0.9521358, 0.9490831] 0.9506268 \n",
"2 0.102038 [0.9534429, 0.96310383] 0.95834136 \n",
"\n",
" train_Recall train_F1 val_loss \\\n",
"0 [0.78240305, 0.86529005] [0.8184067, 0.8274565] 0.145092 \n",
"1 [0.94106466, 0.90215266] [0.92436975, 0.9192423] 0.276596 \n",
"2 [0.936721, 0.94180405] [0.9394231, 0.93907154] 0.100671 \n",
" train_Recall train_F1 val_loss \\\n",
"0 [0.92486894, 0.9248898] [0.9244203, 0.92533314] 0.085667 \n",
"1 [0.9521358, 0.9490831] [0.95122886, 0.95000976] 0.089378 \n",
"2 [0.9534429, 0.96310383] [0.9575638, 0.95909095] 0.115516 \n",
"\n",
" val_Accuracy val_Precision val_Recall \\\n",
"0 [0.97480273, 0.85271317] 0.9737221 [0.97480273, 0.85271317] \n",
"1 [0.878167, 0.9612403] 0.87890226 [0.878167, 0.9612403] \n",
"2 [0.9632424, 0.8992248] 0.9626758 [0.9632424, 0.8992248] \n",
" val_Accuracy val_Precision val_Recall \\\n",
"0 [0.96635747, 0.9147287] 0.96590054 [0.96635747, 0.9147287] \n",
"1 [0.9613042, 0.9302326] 0.9610292 [0.9613042, 0.9302326] \n",
"2 [0.9437215, 0.9457364] 0.9437393 [0.9437215, 0.9457364] \n",
"\n",
" val_F1 \n",
"0 [0.9865836, 0.36484244] \n",
"1 [0.9349596, 0.12319921] \n",
"2 [0.9808275, 0.2989691] "
" val_F1 \n",
"0 [0.98251045, 0.32196453] \n",
"1 [0.9799591, 0.2970297] \n",
"2 [0.97080404, 0.22932333] "
]
},
"execution_count": 21,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
......@@ -989,14 +1003,14 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"predict: 100%|██████████| 292/292 [00:00<00:00, 368.56it/s]\n"
"predict: 100%|██████████| 292/292 [00:00<00:00, 306.28it/s]\n"
]
},
{
......@@ -1005,15 +1019,15 @@
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.96 0.98 14446\n",
" 1 0.18 0.92 0.30 130\n",
" 0 1.00 0.94 0.97 14446\n",
" 1 0.13 0.95 0.23 130\n",
"\n",
" accuracy 0.96 14576\n",
" macro avg 0.59 0.94 0.64 14576\n",
"weighted avg 0.99 0.96 0.97 14576\n",
" accuracy 0.94 14576\n",
" macro avg 0.57 0.95 0.60 14576\n",
"weighted avg 0.99 0.94 0.96 14576\n",
"\n",
"Actual predicted values:\n",
"(array([0, 1]), array([13910, 666]))\n"
"(array([0, 1]), array([13650, 926]))\n"
]
}
],
......
......@@ -11,6 +11,8 @@ from pytorch_widedeep.models import ( # noqa: F401
TabResnet,
)
from pytorch_widedeep.metrics import Accuracy, Precision
# from torchmetrics import Accuracy as accuracy_score
from pytorch_widedeep.callbacks import (
LRHistory,
EarlyStopping,
......@@ -94,6 +96,7 @@ if __name__ == "__main__":
schedulers = {"wide": wide_sch, "deeptabular": deep_sch}
initializers = {"wide": KaimingNormal, "deeptabular": XavierNormal}
callbacks = [early_stopping, model_checkpoint, LRHistory(n_epochs=10)]
# metrics = [Accuracy, accuracy_score(num_classes=2), Precision]
metrics = [Accuracy, Precision]
trainer = Trainer(
......
......@@ -157,6 +157,8 @@ class History(Callback):
):
logs = logs or {}
for k, v in logs.items():
if isinstance(v, np.ndarray):
v = v.tolist()
self.trainer.history.setdefault(k, []).append(v)
......@@ -216,6 +218,12 @@ class LRShedulerCallback(Callback):
class MetricCallback(Callback):
r"""Callback that resets the metrics (if any metric is used)
This callback runs by default within :obj:`Trainer`, therefore, should not
be passed to the :obj:`Trainer`. Is included here just for completion.
"""
def __init__(self, container: MultipleMetrics):
self.container = container
......
......@@ -34,23 +34,18 @@ class DataLoaderDefault(DataLoader):
class DataLoaderImbalanced(DataLoader):
r"""Helper function to load and shuffle tensors into models in batches with
adjusted weights to "fight" against imbalance of the classes. If the
classes do not begin from 0 remapping is necessary, see:
https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab
r"""Class to load and shuffle batches with adjusted weights for imbalanced
datasets. If the classes do not begin from 0 remapping is necessary. See
`here <https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab>`_
Parameters
----------
dataset ``WideDeepDataset``:
dataset containing target classes in dataset.Y
dataset: ``WideDeepDataset``
see ``pytorch_widedeep.training._wd_dataset``
batch_size: int
size of batch
num_workers: int
number of workers
Returns:
--------
PyTorch ``DataLoader`` object
"""
def __init__(
......
......@@ -67,13 +67,13 @@ class Accuracy(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> acc(y_pred, y_true)
0.5
array(0.5)
>>>
>>> acc = Accuracy(top_k=2)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> acc(y_pred, y_true)
0.6666666666666666
array(0.66666667)
"""
def __init__(self, top_k: int = 1):
......@@ -126,13 +126,13 @@ class Precision(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> prec(y_pred, y_true)
0.5
array(0.5)
>>>
>>> prec = Precision(average=True)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> prec(y_pred, y_true)
0.3333333432674408
array(0.33333334)
"""
def __init__(self, average: bool = True):
......@@ -192,13 +192,13 @@ class Recall(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> rec(y_pred, y_true)
0.5
array(0.5)
>>>
>>> rec = Recall(average=True)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> rec(y_pred, y_true)
0.3333333432674408
array(0.33333334)
"""
def __init__(self, average: bool = True):
......@@ -262,13 +262,13 @@ class FBetaScore(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> fbeta(y_pred, y_true)
0.5
array(0.5)
>>>
>>> fbeta = FBetaScore(beta=2)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> fbeta(y_pred, y_true)
0.3333333432674408
array(0.33333334)
"""
def __init__(self, beta: int, average: bool = True):
......@@ -321,13 +321,13 @@ class F1Score(Metric):
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> f1(y_pred, y_true)
0.5
array(0.5)
>>>
>>> f1 = F1Score()
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> f1(y_pred, y_true)
0.3333333432674408
array(0.33333334)
"""
def __init__(self, average: bool = True):
......@@ -367,7 +367,7 @@ class R2Score(Metric):
>>> y_true = torch.tensor([3, -0.5, 2, 7]).view(-1, 1)
>>> y_pred = torch.tensor([2.5, 0.0, 2, 8]).view(-1, 1)
>>> r2(y_pred, y_true)
0.9486081370449679
array(0.94860814)
"""
def __init__(self):
......
......@@ -85,7 +85,7 @@ class Trainer:
function. See for example
:class:`pytorch_widedeep.losses.FocalLoss` for the required
structure of the object or the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo.
.. note:: If ``custom_loss_function`` is not None, ``objective`` must be
......@@ -128,7 +128,7 @@ class Trainer:
callbacks are used by default. This can also be a custom callback as
long as the object of type ``Callback``. See
:obj:`pytorch_widedeep.callbacks.Callback` or the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo
metrics: List, optional, default=None
- List of objects of type :obj:`Metric`. Metrics available are:
......@@ -136,7 +136,7 @@ class Trainer:
``F1Score`` and ``R2Score``. This can also be a custom metric as
long as it is an object of type :obj:`Metric`. See
:obj:`pytorch_widedeep.metrics.Metric` or the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo
- List of objects of type :obj:`torchmetrics.Metric`. This can be any
metric from torchmetrics library `Examples
......@@ -379,11 +379,10 @@ class Trainer:
epochs validation frequency
batch_size: int, default=32
batch size
custom_dataloader: ``torch.utils.data.DataLoader``, Optional, default=None
custom_dataloader: ``DataLoader``, Optional, default=None
object of class ``torch.utils.data.DataLoader``. Available
predefined dataloaders are ``DataLoaderDefault``, ```` in
``pytorch_widedeep.metrics``. If None, a ``DataLoaderDefault``
is set.
predefined dataloaders are in ``pytorch-widedeep.dataloaders``.If
``None``, a standard torch ``DataLoader`` is used.
finetune: bool, default=False
param alias: ``warmup``
......@@ -415,7 +414,7 @@ class Trainer:
For details on how these routines work, please see the Examples
section in this documentation and the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo.
finetune_epochs: int, default=4
param alias: ``warmup_epochs``
......@@ -493,7 +492,7 @@ class Trainer:
--------
For a series of comprehensive examples please, see the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo
For completion, here we include some `"fabricated"` examples, i.e.
......@@ -772,7 +771,7 @@ class Trainer:
--------
For a series of comprehensive examples please, see the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo
For completion, here we include a `"fabricated"` example, i.e.
......@@ -859,7 +858,7 @@ class Trainer:
save_state_dict: bool = False,
model_filename: str = "wd_model.pt",
):
"""Saves the model, training and evaluation history, and the
r"""Saves the model, training and evaluation history, and the
``feature_importance`` attribute (if the ``deeptabular`` component is a
Tabnet model) to disk
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册