diff --git a/examples/15_AUC_multiclass.ipynb b/examples/15_AUC_multiclass.ipynb index 02cb8ea7d9bab08971aeea7ac4f607954d718a88..f0af0e1a7c3d8d57590b250adc5b156903b0306a 100644 --- a/examples/15_AUC_multiclass.ipynb +++ b/examples/15_AUC_multiclass.ipynb @@ -26,8 +26,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-12-10 21:26:26.784170: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", - "2021-12-10 21:26:26.784221: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" + "2021-12-11 16:34:55.734357: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2021-12-11 16:34:55.734404: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" ] } ], @@ -39,7 +39,7 @@ "from pytorch_widedeep import Trainer\n", "from pytorch_widedeep.preprocessing import TabPreprocessor\n", "from pytorch_widedeep.models import TabMlp, WideDeep\n", - "from torchmetrics import AUC\n", + "from torchmetrics import AUC, AUROC\n", "from pytorch_widedeep.initializers import XavierNormal\n", "from pytorch_widedeep.datasets import load_ecoli\n", "from pytorch_widedeep.utils import LabelEncoder\n", @@ -53,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 2, "id": "07c75f0c", "metadata": {}, "outputs": [ @@ -163,7 +163,7 @@ "4 ADI_ECOLI 0.23 0.32 0.48 0.5 0.55 0.25 0.35 cp" ] }, - "execution_count": 26, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -175,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 3, "id": "1e3f8efc", "metadata": {}, "outputs": [ @@ -193,7 +193,7 @@ "Name: class, dtype: int64" ] }, - "execution_count": 27, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -205,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 4, "id": "e4db0d6d", "metadata": {}, "outputs": [], @@ -216,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 5, "id": "005531a3", "metadata": {}, "outputs": [], @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 6, "id": "214b3071", "metadata": {}, "outputs": [], @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 7, "id": "168c81f1", "metadata": {}, "outputs": [], @@ -258,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 8, "id": "3a7b246b", "metadata": {}, "outputs": [], @@ -268,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 9, "id": "7a2dac24", "metadata": {}, "outputs": [], @@ -298,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 10, "id": "511198d4", "metadata": {}, "outputs": [ @@ -331,7 +331,7 @@ ")" ] }, - "execution_count": 39, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -347,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 11, "id": "a5359b0f", "metadata": {}, "outputs": [ @@ -356,49 +356,40 @@ "output_type": "stream", "text": [ "/home/palo/miniconda3/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n", + " warnings.warn(*args, **kwargs)\n", + "/home/palo/miniconda3/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n", " warnings.warn(*args, **kwargs)\n" ] } ], "source": [ - "auc = AUC()" + "auc = AUC(reorder=True)\n", + "auc.num_classes = df_enc[\"class\"].nunique()\n", + "auroc = AUROC(num_classes=df_enc[\"class\"].nunique())" ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 12, "id": "34a18ac0", "metadata": { "scrolled": false }, "outputs": [ { - "ename": "AttributeError", - "evalue": "'AUC' object has no attribute 'num_classes'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_1193/1577559620.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m )\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_val\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X_wide, X_tab, X_text, X_img, X_train, X_val, val_split, target, n_epochs, validation_freq, batch_size, custom_dataloader, finetune, finetune_epochs, finetune_max_lr, finetune_deeptabular_gradual, finetune_deeptabular_max_lr, finetune_deeptabular_layers, finetune_deeptext_gradual, finetune_deeptext_max_lr, finetune_deeptext_layers, finetune_deepimage_gradual, finetune_deepimage_max_lr, finetune_deepimage_layers, finetune_routine, stop_after_finetuning, **kwargs)\u001b[0m\n\u001b[1;32m 615\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargett\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 616\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_description\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"epoch %i\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m \u001b[0mtrain_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargett\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 618\u001b[0m \u001b[0mprint_loss_and_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 619\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36m_train_step\u001b[0;34m(self, data, target, batch_idx)\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1169\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1170\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1171\u001b[0m \u001b[0;31m# TODO raise exception if the loss is exploding with non scaled target values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1172\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36m_get_score\u001b[0;34m(self, y_pred, y)\u001b[0m\n\u001b[1;32m 1212\u001b[0m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1213\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"multiclass\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1214\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1215\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1216\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/metrics.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, y_pred, y_true)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mlogs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprefix\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTorchMetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_classes\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_classes\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[operator]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1128\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1129\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1130\u001b[0;31m raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 1131\u001b[0m type(self).__name__, name))\n\u001b[1;32m 1132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'AUC' object has no attribute 'num_classes'" + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|██████████| 6/6 [00:00<00:00, 79.20it/s, loss=0.1, metrics={'AUC': 8.5, 'AUROC': 0.427}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 6.06it/s, loss=0.0961, metrics={'AUC': 6.5, 'AUROC': 0.419}]\n", + "epoch 2: 100%|██████████| 6/6 [00:00<00:00, 82.52it/s, loss=0.095, metrics={'AUC': 4.5, 'AUROC': 0.4418}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.69it/s, loss=0.0917, metrics={'AUC': 6.5, 'AUROC': 0.4351}]\n", + "epoch 3: 100%|██████████| 6/6 [00:00<00:00, 103.30it/s, loss=0.0908, metrics={'AUC': 5.5, 'AUROC': 0.4715}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.35it/s, loss=0.0875, metrics={'AUC': 6.5, 'AUROC': 0.4633}]\n", + "epoch 4: 100%|██████████| 6/6 [00:00<00:00, 90.88it/s, loss=0.0872, metrics={'AUC': 7.0, 'AUROC': 0.4767}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.44it/s, loss=0.0874, metrics={'AUC': 6.5, 'AUROC': 0.4652}]\n", + "epoch 5: 100%|██████████| 6/6 [00:00<00:00, 88.87it/s, loss=0.0866, metrics={'AUC': 6.0, 'AUROC': 0.4775}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.37it/s, loss=0.087, metrics={'AUC': 6.5, 'AUROC': 0.4524}]\n" ] } ], @@ -414,11 +405,10 @@ " lr_schedulers={\"deeptabular\": deep_sch},\n", " initializers={\"deeptabular\": XavierNormal},\n", " optimizers={\"deeptabular\": deep_opt},\n", - " metrics=[auc],\n", - " verbose=0,\n", + " metrics=[auc, auroc],\n", ")\n", "\n", - "trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=10)" + "trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=50)" ] } ], diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index 2486eb13ab8fa562becdc0c930deffa66a63a6ea..e939ed8011dd7f32a92f28c732cfb11a103ac2ed 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -1,6 +1,7 @@ import numpy as np import torch from torchmetrics import Metric as TorchMetric +from torchmetrics import AUC from .wdtypes import * # noqa: F403 @@ -38,10 +39,23 @@ class MultipleMetrics(object): if isinstance(metric, Metric): logs[self.prefix + metric._name] = metric(y_pred, y_true) if isinstance(metric, TorchMetric): + if not hasattr(metric, "num_classes"): + raise ValueError( + """TorchMetric does not have num_classes attribute. + Use metric in this library or extend the metric by num_classes attribute, + see `examples ` + """ + ) if metric.num_classes == 2: - metric.update(torch.round(y_pred).int(), y_true.int()) + if isinstance(metric, AUC): + metric.update(torch.round(y_pred).int(), y_true.int()) + else: + metric.update(y_pred, y_true.int()) if metric.num_classes > 2: # type: ignore[operator] - metric.update(torch.max(y_pred, dim=1).indices, y_true.int()) # type: ignore[attr-defined] + if isinstance(metric, AUC): + metric.update(torch.max(y_pred, dim=1).indices, y_true.int()) # type: ignore[attr-defined] + else: + metric.update(y_pred, y_true.int()) # type: ignore[attr-defined] logs[self.prefix + type(metric).__name__] = ( metric.compute().detach().cpu().numpy() ) @@ -396,3 +410,62 @@ class R2Score(Metric): y_true_avg = self.y_true_sum / self.num_examples self.denominator += ((y_true - y_true_avg) ** 2).sum().item() return np.array((1 - (self.numerator / self.denominator))) + + +class Accuracy(Metric): + r"""Class to calculate the accuracy for both binary and categorical problems + + Parameters + ---------- + top_k: int, default = 1 + Accuracy will be computed using the top k most likely classes in + multiclass problems + + Examples + -------- + >>> import torch + >>> + >>> from pytorch_widedeep.metrics import Accuracy + >>> + >>> acc = Accuracy() + >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) + >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) + >>> acc(y_pred, y_true) + array(0.5) + >>> + >>> acc = Accuracy(top_k=2) + >>> y_true = torch.tensor([0, 1, 2]) + >>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) + >>> acc(y_pred, y_true) + array(0.66666667) + """ + + def __init__(self, top_k: int = 1): + super(Accuracy, self).__init__() + + self.top_k = top_k + self.correct_count = 0 + self.total_count = 0 + self._name = "acc" + + def reset(self): + """ + resets counters to 0 + """ + self.correct_count = 0 + self.total_count = 0 + + def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: + num_classes = y_pred.size(1) + + if num_classes == 1: + y_pred = y_pred.round() + y_true = y_true + elif num_classes > 1: + y_pred = y_pred.topk(self.top_k, 1)[1] + y_true = y_true.view(-1, 1).expand_as(y_pred) + + self.correct_count += y_pred.eq(y_true).sum().item() # type: ignore[assignment] + self.total_count += len(y_pred) + accuracy = float(self.correct_count) / float(self.total_count) + return np.array(accuracy) diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 4d32c829cff0c7a45fa961941306ad1b38c8f116..db9004a0cc6053d6da48745356f4e59c8aa6648d 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -147,10 +147,14 @@ class Trainer: `__ folder in the repo - List of objects of type :obj:`torchmetrics.Metric`. This can be any - metric from torchmetrics library `Examples + metric from torchmetrics library that has attribute num_classes `Examples `_. This can also be a custom metric as - long as it is an object of type :obj:`Metric`. See `the instructions + classification-metrics>`_. + Objects of type :obj:`torchmetrics.Metric` can be extended with num_classes + attribute to be used with the Trainer object, see `examples + `. + This can also be a custom metric as long as it is an object of + type :obj:`Metric`. See `the instructions `_. class_weight: float, List or Tuple. optional. default=None - float indicating the weight of the minority class in binary classification diff --git a/tests/test_metrics/test_torchmetrics.py b/tests/test_metrics/test_torchmetrics.py index a5bd69ae2aa375db8df16f493a4e21100799d978..c28495f45300928f670998202016ac21565cf934 100644 --- a/tests/test_metrics/test_torchmetrics.py +++ b/tests/test_metrics/test_torchmetrics.py @@ -1,13 +1,14 @@ import numpy as np import torch import pytest -from torchmetrics import F1, FBeta, Recall, Accuracy, Precision +from torchmetrics import F1, FBeta, Recall, Accuracy, Precision, AUC from sklearn.metrics import ( f1_score, fbeta_score, recall_score, accuracy_score, precision_score, + auc_score, ) from pytorch_widedeep.metrics import MultipleMetrics @@ -35,9 +36,12 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np) ("Recall", recall_score, Recall(num_classes=2, average="none")), ("F1", f1_score, F1(num_classes=2, average="none")), ("FBeta", f2_score_bin, FBeta(beta=2, num_classes=2, average="none")), + ("AUC", auc_score, AUC()), ], ) def test_binary_metrics(metric_name, sklearn_metric, torch_metric): + if metric_name == "AUC": + torch_metric.num_classes=2 sk_res = sklearn_metric(y_true_bin_np, y_pred_bin_np.round()) wd_metric = MultipleMetrics(metrics=[torch_metric]) wd_logs = wd_metric(y_pred_bin_pt, y_true_bin_pt) @@ -82,11 +86,14 @@ def f2_score_multi(y_true, y_pred, average): ("Recall", recall_score, Recall(num_classes=3, average="macro")), ("F1", f1_score, F1(num_classes=3, average="macro")), ("FBeta", f2_score_multi, FBeta(beta=3, num_classes=3, average="macro")), + ("AUC", auc_score, AUC()), ], ) def test_muticlass_metrics(metric_name, sklearn_metric, torch_metric): if metric_name == "Accuracy": sk_res = sklearn_metric(y_true_multi_np, y_pred_muli_np.argmax(axis=1)) + elif metric_name == "AUC": + torch_metric.num_classes=3 else: sk_res = sklearn_metric( y_true_multi_np, y_pred_muli_np.argmax(axis=1), average="macro"