diff --git a/docs/callbacks.rst b/docs/callbacks.rst index 05c4c9602a5a878a32bf9de7855c84ae74bcfb3f..c5ef2ef20692d6462f1b59bbb2d32d88cabeb4a2 100644 --- a/docs/callbacks.rst +++ b/docs/callbacks.rst @@ -13,6 +13,9 @@ Here are the 4 callbacks available in ``pytorch-widedepp``: ``History``, .. autoclass:: pytorch_widedeep.callbacks.LRShedulerCallback :members: +.. autoclass:: pytorch_widedeep.callbacks.MetricCallback + :members: + .. autoclass:: pytorch_widedeep.callbacks.LRHistory :members: diff --git a/docs/dataloaders.rst b/docs/dataloaders.rst new file mode 100644 index 0000000000000000000000000000000000000000..9d80a390c7ff06093dd50841a1fe06ab8f67ffd9 --- /dev/null +++ b/docs/dataloaders.rst @@ -0,0 +1,11 @@ +Dataloaders +=========== + +.. note:: This module should contain custom dataloaders that the user might want to + implement. At the moment ``pytorch-widedeep`` offers one custom dataloader, + ``DataLoaderImbalanced``. + + +.. autoclass:: pytorch_widedeep.dataloaders.DataLoaderImbalanced + :members: + :undoc-members: diff --git a/docs/examples.rst b/docs/examples.rst index 4f878e9e608db9e4f3d8ec7cb977f6a49d04ac75..2ddbf04459f150243329911a05a0413a72ca495b 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -14,3 +14,4 @@ them to address different problems * `FineTune routines `__ * `Custom Components `__ * `Save and Load Model and Artifacts `__ +* `Using Custom DataLoaders and Torchmetrics `__ diff --git a/docs/index.rst b/docs/index.rst index 9077081cd4ffbeb55da145c016bf4ea3efaf7330..96659654795ede993088ef4ec9b20dd14920334e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,6 +20,7 @@ Documentation Model Components Metrics Losses + Dataloaders Callbacks The Trainer Examples diff --git a/docs/installation.rst b/docs/installation.rst index 7c89ba623cb7a0517a1a49831a629cc6331540e6..11500c0f9855cbf731ddf2508e24fcde96788fff 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -40,4 +40,5 @@ Dependencies * torch * torchvision * einops -* wrapt \ No newline at end of file +* wrapt +* torchmetrics \ No newline at end of file diff --git a/docs/metrics.rst b/docs/metrics.rst index 33c6482200f5a1358b5e38d340b6db7a21c840eb..4869f465cda0f86358f8272bd1ac931ca15f0583 100644 --- a/docs/metrics.rst +++ b/docs/metrics.rst @@ -7,6 +7,26 @@ Metrics ground truth is expected to be a 1D tensor with the corresponding classes. See Examples below +We have added the possibility of using the metrics available at the +`torchmetrics `_ library. +Note that this library is still in its early versions and therefore this +option should be used with caution. To use ``torchmetrics`` simply import +them and use them as any of the ``pytorch-widedeep`` metrics described +below. + +.. code-block:: python + + from torchmetrics import Accuracy, Precision + + accuracy = Accuracy(average=None, num_classes=2) + precision = Precision(average='micro', num_classes=2) + + trainer = Trainer(model, objective="binary", metrics=[accuracy, precision]) + +A functioning example for ``pytorch-widedeep`` using ``torchmetrics`` can be +found in the `Examples folder `_. + + .. autoclass:: pytorch_widedeep.metrics.Accuracy :members: :undoc-members: diff --git a/docs/requirements.txt b/docs/requirements.txt index 74da329ca6230ab9f1101f36f1584b6206837c9b..bab197235030907c1efdbe9be159c0b030b0f7f9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -16,4 +16,5 @@ tqdm torch torchvision einops -wrapt \ No newline at end of file +wrapt +torchmetrics \ No newline at end of file diff --git a/docs/trainer.rst b/docs/trainer.rst index 7483c9bb0032034d63dd3e34e0cd485c02366044..f203569f907f74148aa10b4157f0ef5788347117 100644 --- a/docs/trainer.rst +++ b/docs/trainer.rst @@ -1,5 +1,5 @@ Training wide and deep models for tabular data -============================================== +=============================================== `...` or just deep learning models for tabular data. diff --git a/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb b/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb index 34e4c3adc58ad5ce766db230928015c3fe5f3396..d0b2370676a736f0a26b6a5b737f9048d658092b 100644 --- a/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb +++ b/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb @@ -32,7 +32,16 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/javier/.pyenv/versions/3.7.7/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n", + " return f(*args, **kwds)\n" + ] + } + ], "source": [ "import numpy as np\n", "import pandas as pd\n", @@ -66,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -625,20 +634,20 @@ "4 0.68 -0.59 2.0 -36.0 -6.9 2.02 0.14 -0.23 " ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "header_list = ['EXAMPLE_ID', 'BLOCK_ID', 'target'] + [str(i) for i in range(4,78)]\n", - "df = pd.read_csv('data/bio_train.dat', sep='\\t', names=header_list)\n", + "df = pd.read_csv('data/kddcup04/bio_train.dat', sep='\\t', names=header_list)\n", "df.head()" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -649,7 +658,7 @@ "Name: target, dtype: int64" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -661,7 +670,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -671,7 +680,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -688,7 +697,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -697,7 +706,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -723,7 +732,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -734,7 +743,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -778,7 +787,7 @@ ")" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -793,7 +802,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -810,16 +819,16 @@ "metadata": {}, "outputs": [], "source": [ - "# Metrics from pytorch-widedeep\n", - "accuracy = Accuracy(top_k=2)\n", - "precision = Precision(average=False)\n", - "recall = Recall(average=True)\n", - "f1 = F1Score(average=False)" + "# # Metrics from pytorch-widedeep\n", + "# accuracy = Accuracy(top_k=2)\n", + "# precision = Precision(average=False)\n", + "# recall = Recall(average=True)\n", + "# f1 = F1Score(average=False)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -839,26 +848,31 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 208/208 [00:01<00:00, 163.29it/s, loss=0.126, metrics={'Accuracy': [0.9522, 0.9451], 'Precision': 0.9486, 'Recall': [0.9522, 0.9451], 'F1': [0.9482, 0.949]}] \n", - "valid: 100%|██████████| 292/292 [00:01<00:00, 161.79it/s, loss=0.125, metrics={'Accuracy': [0.9466, 0.9302], 'Precision': 0.9464, 'Recall': [0.9466, 0.9302], 'F1': [0.9722, 0.2351]}]\n", - "epoch 2: 100%|██████████| 208/208 [00:01<00:00, 165.78it/s, loss=0.129, metrics={'Accuracy': [0.9579, 0.9431], 'Precision': 0.9504, 'Recall': [0.9579, 0.9431], 'F1': [0.9503, 0.9506]}]\n", - "valid: 100%|██████████| 292/292 [00:01<00:00, 163.66it/s, loss=0.127, metrics={'Accuracy': [0.946, 0.9302], 'Precision': 0.9459, 'Recall': [0.946, 0.9302], 'F1': [0.9719, 0.2332]}] \n", - "epoch 3: 100%|██████████| 208/208 [00:01<00:00, 162.24it/s, loss=0.124, metrics={'Accuracy': [0.9524, 0.9462], 'Precision': 0.9493, 'Recall': [0.9524, 0.9462], 'F1': [0.9491, 0.9495]}]\n", - "valid: 100%|██████████| 292/292 [00:01<00:00, 161.53it/s, loss=0.127, metrics={'Accuracy': [0.9457, 0.9302], 'Precision': 0.9456, 'Recall': [0.9457, 0.9302], 'F1': [0.9718, 0.2323]}]\n" + "epoch 1: 0%| | 0/208 [00:00\n", " \n", " 0\n", - " 0.358859\n", - " [0.78240305, 0.86529005]\n", - " 0.8230473\n", - " [0.78240305, 0.86529005]\n", - " [0.8184067, 0.8274565]\n", - " 0.145092\n", - " [0.97480273, 0.85271317]\n", - " 0.9737221\n", - " [0.97480273, 0.85271317]\n", - " [0.9865836, 0.36484244]\n", + " 0.187643\n", + " [0.92486894, 0.9248898]\n", + " 0.92487943\n", + " [0.92486894, 0.9248898]\n", + " [0.9244203, 0.92533314]\n", + " 0.085667\n", + " [0.96635747, 0.9147287]\n", + " 0.96590054\n", + " [0.96635747, 0.9147287]\n", + " [0.98251045, 0.32196453]\n", " \n", " \n", " 1\n", - " 0.201815\n", - " [0.94106466, 0.90215266]\n", - " 0.9218901\n", - " [0.94106466, 0.90215266]\n", - " [0.92436975, 0.9192423]\n", - " 0.276596\n", - " [0.878167, 0.9612403]\n", - " 0.87890226\n", - " [0.878167, 0.9612403]\n", - " [0.9349596, 0.12319921]\n", + " 0.121485\n", + " [0.9521358, 0.9490831]\n", + " 0.9506268\n", + " [0.9521358, 0.9490831]\n", + " [0.95122886, 0.95000976]\n", + " 0.089378\n", + " [0.9613042, 0.9302326]\n", + " 0.9610292\n", + " [0.9613042, 0.9302326]\n", + " [0.9799591, 0.2970297]\n", " \n", " \n", " 2\n", - " 0.154379\n", - " [0.936721, 0.94180405]\n", - " 0.93924785\n", - " [0.936721, 0.94180405]\n", - " [0.9394231, 0.93907154]\n", - " 0.100671\n", - " [0.9632424, 0.8992248]\n", - " 0.9626758\n", - " [0.9632424, 0.8992248]\n", - " [0.9808275, 0.2989691]\n", + " 0.102038\n", + " [0.9534429, 0.96310383]\n", + " 0.95834136\n", + " [0.9534429, 0.96310383]\n", + " [0.9575638, 0.95909095]\n", + " 0.115516\n", + " [0.9437215, 0.9457364]\n", + " 0.9437393\n", + " [0.9437215, 0.9457364]\n", + " [0.97080404, 0.22932333]\n", " \n", " \n", "\n", "" ], "text/plain": [ - " train_loss train_Accuracy train_Precision \\\n", - "0 0.358859 [0.78240305, 0.86529005] 0.8230473 \n", - "1 0.201815 [0.94106466, 0.90215266] 0.9218901 \n", - "2 0.154379 [0.936721, 0.94180405] 0.93924785 \n", + " train_loss train_Accuracy train_Precision \\\n", + "0 0.187643 [0.92486894, 0.9248898] 0.92487943 \n", + "1 0.121485 [0.9521358, 0.9490831] 0.9506268 \n", + "2 0.102038 [0.9534429, 0.96310383] 0.95834136 \n", "\n", - " train_Recall train_F1 val_loss \\\n", - "0 [0.78240305, 0.86529005] [0.8184067, 0.8274565] 0.145092 \n", - "1 [0.94106466, 0.90215266] [0.92436975, 0.9192423] 0.276596 \n", - "2 [0.936721, 0.94180405] [0.9394231, 0.93907154] 0.100671 \n", + " train_Recall train_F1 val_loss \\\n", + "0 [0.92486894, 0.9248898] [0.9244203, 0.92533314] 0.085667 \n", + "1 [0.9521358, 0.9490831] [0.95122886, 0.95000976] 0.089378 \n", + "2 [0.9534429, 0.96310383] [0.9575638, 0.95909095] 0.115516 \n", "\n", - " val_Accuracy val_Precision val_Recall \\\n", - "0 [0.97480273, 0.85271317] 0.9737221 [0.97480273, 0.85271317] \n", - "1 [0.878167, 0.9612403] 0.87890226 [0.878167, 0.9612403] \n", - "2 [0.9632424, 0.8992248] 0.9626758 [0.9632424, 0.8992248] \n", + " val_Accuracy val_Precision val_Recall \\\n", + "0 [0.96635747, 0.9147287] 0.96590054 [0.96635747, 0.9147287] \n", + "1 [0.9613042, 0.9302326] 0.9610292 [0.9613042, 0.9302326] \n", + "2 [0.9437215, 0.9457364] 0.9437393 [0.9437215, 0.9457364] \n", "\n", - " val_F1 \n", - "0 [0.9865836, 0.36484244] \n", - "1 [0.9349596, 0.12319921] \n", - "2 [0.9808275, 0.2989691] " + " val_F1 \n", + "0 [0.98251045, 0.32196453] \n", + "1 [0.9799591, 0.2970297] \n", + "2 [0.97080404, 0.22932333] " ] }, - "execution_count": 21, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -989,14 +1003,14 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "predict: 100%|██████████| 292/292 [00:00<00:00, 368.56it/s]\n" + "predict: 100%|██████████| 292/292 [00:00<00:00, 306.28it/s]\n" ] }, { @@ -1005,15 +1019,15 @@ "text": [ " precision recall f1-score support\n", "\n", - " 0 1.00 0.96 0.98 14446\n", - " 1 0.18 0.92 0.30 130\n", + " 0 1.00 0.94 0.97 14446\n", + " 1 0.13 0.95 0.23 130\n", "\n", - " accuracy 0.96 14576\n", - " macro avg 0.59 0.94 0.64 14576\n", - "weighted avg 0.99 0.96 0.97 14576\n", + " accuracy 0.94 14576\n", + " macro avg 0.57 0.95 0.60 14576\n", + "weighted avg 0.99 0.94 0.96 14576\n", "\n", "Actual predicted values:\n", - "(array([0, 1]), array([13910, 666]))\n" + "(array([0, 1]), array([13650, 926]))\n" ] } ], diff --git a/examples/adult_census.py b/examples/adult_census.py index 6a34613b263cbd678bc459585b3ac1ac0d3ea4a1..0637c02a67113f41ea646f2d1fde3b0ccd5bd9c9 100644 --- a/examples/adult_census.py +++ b/examples/adult_census.py @@ -11,6 +11,8 @@ from pytorch_widedeep.models import ( # noqa: F401 TabResnet, ) from pytorch_widedeep.metrics import Accuracy, Precision + +# from torchmetrics import Accuracy as accuracy_score from pytorch_widedeep.callbacks import ( LRHistory, EarlyStopping, @@ -94,6 +96,7 @@ if __name__ == "__main__": schedulers = {"wide": wide_sch, "deeptabular": deep_sch} initializers = {"wide": KaimingNormal, "deeptabular": XavierNormal} callbacks = [early_stopping, model_checkpoint, LRHistory(n_epochs=10)] + # metrics = [Accuracy, accuracy_score(num_classes=2), Precision] metrics = [Accuracy, Precision] trainer = Trainer( diff --git a/pytorch_widedeep/callbacks.py b/pytorch_widedeep/callbacks.py index 869463bd15c0d2e62a2f20cda6c35ee83a657a56..4fd4a7c6f0bed4b100d1378338f80369d927e566 100644 --- a/pytorch_widedeep/callbacks.py +++ b/pytorch_widedeep/callbacks.py @@ -157,6 +157,8 @@ class History(Callback): ): logs = logs or {} for k, v in logs.items(): + if isinstance(v, np.ndarray): + v = v.tolist() self.trainer.history.setdefault(k, []).append(v) @@ -216,6 +218,12 @@ class LRShedulerCallback(Callback): class MetricCallback(Callback): + r"""Callback that resets the metrics (if any metric is used) + + This callback runs by default within :obj:`Trainer`, therefore, should not + be passed to the :obj:`Trainer`. Is included here just for completion. + """ + def __init__(self, container: MultipleMetrics): self.container = container diff --git a/pytorch_widedeep/dataloaders.py b/pytorch_widedeep/dataloaders.py index 76f51509cb911a3f4b367ac1949377937cbe9de6..df784a8c296c6491b127ef14ee61e4cf6ed25d99 100644 --- a/pytorch_widedeep/dataloaders.py +++ b/pytorch_widedeep/dataloaders.py @@ -34,23 +34,18 @@ class DataLoaderDefault(DataLoader): class DataLoaderImbalanced(DataLoader): - r"""Helper function to load and shuffle tensors into models in batches with - adjusted weights to "fight" against imbalance of the classes. If the - classes do not begin from 0 remapping is necessary, see: - https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab + r"""Class to load and shuffle batches with adjusted weights for imbalanced + datasets. If the classes do not begin from 0 remapping is necessary. See + `here `_ Parameters ---------- - dataset ``WideDeepDataset``: - dataset containing target classes in dataset.Y + dataset: ``WideDeepDataset`` + see ``pytorch_widedeep.training._wd_dataset`` batch_size: int size of batch num_workers: int number of workers - - Returns: - -------- - PyTorch ``DataLoader`` object """ def __init__( diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index 5cfb565d7acd9d934f73df56d660f64495c9e7ea..2486eb13ab8fa562becdc0c930deffa66a63a6ea 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -67,13 +67,13 @@ class Accuracy(Metric): >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> acc(y_pred, y_true) - 0.5 + array(0.5) >>> >>> acc = Accuracy(top_k=2) >>> y_true = torch.tensor([0, 1, 2]) >>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> acc(y_pred, y_true) - 0.6666666666666666 + array(0.66666667) """ def __init__(self, top_k: int = 1): @@ -126,13 +126,13 @@ class Precision(Metric): >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> prec(y_pred, y_true) - 0.5 + array(0.5) >>> >>> prec = Precision(average=True) >>> y_true = torch.tensor([0, 1, 2]) >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> prec(y_pred, y_true) - 0.3333333432674408 + array(0.33333334) """ def __init__(self, average: bool = True): @@ -192,13 +192,13 @@ class Recall(Metric): >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> rec(y_pred, y_true) - 0.5 + array(0.5) >>> >>> rec = Recall(average=True) >>> y_true = torch.tensor([0, 1, 2]) >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> rec(y_pred, y_true) - 0.3333333432674408 + array(0.33333334) """ def __init__(self, average: bool = True): @@ -262,13 +262,13 @@ class FBetaScore(Metric): >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> fbeta(y_pred, y_true) - 0.5 + array(0.5) >>> >>> fbeta = FBetaScore(beta=2) >>> y_true = torch.tensor([0, 1, 2]) >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> fbeta(y_pred, y_true) - 0.3333333432674408 + array(0.33333334) """ def __init__(self, beta: int, average: bool = True): @@ -321,13 +321,13 @@ class F1Score(Metric): >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) >>> f1(y_pred, y_true) - 0.5 + array(0.5) >>> >>> f1 = F1Score() >>> y_true = torch.tensor([0, 1, 2]) >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) >>> f1(y_pred, y_true) - 0.3333333432674408 + array(0.33333334) """ def __init__(self, average: bool = True): @@ -367,7 +367,7 @@ class R2Score(Metric): >>> y_true = torch.tensor([3, -0.5, 2, 7]).view(-1, 1) >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]).view(-1, 1) >>> r2(y_pred, y_true) - 0.9486081370449679 + array(0.94860814) """ def __init__(self): diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 770f2f61a47e522d61c2a029156cc34836f1d14a..44c37d1aac213c72538ce0cc7851154a686c9204 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -85,7 +85,7 @@ class Trainer: function. See for example :class:`pytorch_widedeep.losses.FocalLoss` for the required structure of the object or the `Examples - `_ + `__ folder in the repo. .. note:: If ``custom_loss_function`` is not None, ``objective`` must be @@ -128,7 +128,7 @@ class Trainer: callbacks are used by default. This can also be a custom callback as long as the object of type ``Callback``. See :obj:`pytorch_widedeep.callbacks.Callback` or the `Examples - `_ + `__ folder in the repo metrics: List, optional, default=None - List of objects of type :obj:`Metric`. Metrics available are: @@ -136,7 +136,7 @@ class Trainer: ``F1Score`` and ``R2Score``. This can also be a custom metric as long as it is an object of type :obj:`Metric`. See :obj:`pytorch_widedeep.metrics.Metric` or the `Examples - `_ + `__ folder in the repo - List of objects of type :obj:`torchmetrics.Metric`. This can be any metric from torchmetrics library `Examples @@ -379,11 +379,10 @@ class Trainer: epochs validation frequency batch_size: int, default=32 batch size - custom_dataloader: ``torch.utils.data.DataLoader``, Optional, default=None + custom_dataloader: ``DataLoader``, Optional, default=None object of class ``torch.utils.data.DataLoader``. Available - predefined dataloaders are ``DataLoaderDefault``, ```` in - ``pytorch_widedeep.metrics``. If None, a ``DataLoaderDefault`` - is set. + predefined dataloaders are in ``pytorch-widedeep.dataloaders``.If + ``None``, a standard torch ``DataLoader`` is used. finetune: bool, default=False param alias: ``warmup`` @@ -415,7 +414,7 @@ class Trainer: For details on how these routines work, please see the Examples section in this documentation and the `Examples - `_ + `__ folder in the repo. finetune_epochs: int, default=4 param alias: ``warmup_epochs`` @@ -493,7 +492,7 @@ class Trainer: -------- For a series of comprehensive examples please, see the `Examples - `_ + `__ folder in the repo For completion, here we include some `"fabricated"` examples, i.e. @@ -772,7 +771,7 @@ class Trainer: -------- For a series of comprehensive examples please, see the `Examples - `_ + `__ folder in the repo For completion, here we include a `"fabricated"` example, i.e. @@ -859,7 +858,7 @@ class Trainer: save_state_dict: bool = False, model_filename: str = "wd_model.pt", ): - """Saves the model, training and evaluation history, and the + r"""Saves the model, training and evaluation history, and the ``feature_importance`` attribute (if the ``deeptabular`` component is a Tabnet model) to disk