Commit d13e86f9 authored by: Pavol Mulinka

fixed multiclass torchmetrics

Parent ef9ea277
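
The change below wires multiclass torchmetrics objects into the pytorch-widedeep Trainer. As a quick orientation, here is a minimal sketch of the usage this commit enables, assuming the notebook's preprocessed data (df_enc, X_train, X_val) and a built WideDeep model as in the diff, and the torchmetrics AUC/AUROC API available at the time of the commit:

from torchmetrics import AUC, AUROC
from pytorch_widedeep import Trainer

n_classes = df_enc["class"].nunique()   # number of target classes, taken from the notebook below

# AUC has no num_classes argument, so the attribute is attached by hand;
# this is what the library's MultipleMetrics container now checks for.
auc = AUC(reorder=True)
auc.num_classes = n_classes
auroc = AUROC(num_classes=n_classes)    # AUROC accepts num_classes natively

trainer = Trainer(model, objective="multiclass", metrics=[auc, auroc])
trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=50)
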
......@@ -26,8 +26,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-10 21:26:26.784170: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2021-12-10 21:26:26.784221: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
"2021-12-11 16:34:55.734357: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2021-12-11 16:34:55.734404: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
]
}
],
......@@ -39,7 +39,7 @@
"from pytorch_widedeep import Trainer\n",
"from pytorch_widedeep.preprocessing import TabPreprocessor\n",
"from pytorch_widedeep.models import TabMlp, WideDeep\n",
"from torchmetrics import AUC\n",
"from torchmetrics import AUC, AUROC\n",
"from pytorch_widedeep.initializers import XavierNormal\n",
"from pytorch_widedeep.datasets import load_ecoli\n",
"from pytorch_widedeep.utils import LabelEncoder\n",
......@@ -53,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 2,
"id": "07c75f0c",
"metadata": {},
"outputs": [
......@@ -163,7 +163,7 @@
"4 ADI_ECOLI 0.23 0.32 0.48 0.5 0.55 0.25 0.35 cp"
]
},
"execution_count": 26,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
......@@ -175,7 +175,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 3,
"id": "1e3f8efc",
"metadata": {},
"outputs": [
......@@ -193,7 +193,7 @@
"Name: class, dtype: int64"
]
},
"execution_count": 27,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
......@@ -205,7 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 4,
"id": "e4db0d6d",
"metadata": {},
"outputs": [],
......@@ -216,7 +216,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 5,
"id": "005531a3",
"metadata": {},
"outputs": [],
......@@ -228,7 +228,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 6,
"id": "214b3071",
"metadata": {},
"outputs": [],
......@@ -239,7 +239,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 7,
"id": "168c81f1",
"metadata": {},
"outputs": [],
......@@ -258,7 +258,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 8,
"id": "3a7b246b",
"metadata": {},
"outputs": [],
......@@ -268,7 +268,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 9,
"id": "7a2dac24",
"metadata": {},
"outputs": [],
......@@ -298,7 +298,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 10,
"id": "511198d4",
"metadata": {},
"outputs": [
......@@ -331,7 +331,7 @@
")"
]
},
"execution_count": 39,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
......@@ -347,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 11,
"id": "a5359b0f",
"metadata": {},
"outputs": [
......@@ -356,49 +356,40 @@
"output_type": "stream",
"text": [
"/home/palo/miniconda3/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs)\n",
"/home/palo/miniconda3/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs)\n"
]
}
],
"source": [
"auc = AUC()"
"auc = AUC(reorder=True)\n",
"auc.num_classes = df_enc[\"class\"].nunique()\n",
"auroc = AUROC(num_classes=df_enc[\"class\"].nunique())"
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 12,
"id": "34a18ac0",
"metadata": {
"scrolled": false
},
"outputs": [
{
"ename": "AttributeError",
"evalue": "'AUC' object has no attribute 'num_classes'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_1193/1577559620.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m )\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_val\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X_wide, X_tab, X_text, X_img, X_train, X_val, val_split, target, n_epochs, validation_freq, batch_size, custom_dataloader, finetune, finetune_epochs, finetune_max_lr, finetune_deeptabular_gradual, finetune_deeptabular_max_lr, finetune_deeptabular_layers, finetune_deeptext_gradual, finetune_deeptext_max_lr, finetune_deeptext_layers, finetune_deepimage_gradual, finetune_deepimage_max_lr, finetune_deepimage_layers, finetune_routine, stop_after_finetuning, **kwargs)\u001b[0m\n\u001b[1;32m 615\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargett\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 616\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_description\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"epoch %i\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m \u001b[0mtrain_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargett\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 618\u001b[0m \u001b[0mprint_loss_and_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 619\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36m_train_step\u001b[0;34m(self, data, target, batch_idx)\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1169\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1170\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1171\u001b[0m \u001b[0;31m# TODO raise exception if the loss is exploding with non scaled target values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1172\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36m_get_score\u001b[0;34m(self, y_pred, y)\u001b[0m\n\u001b[1;32m 1212\u001b[0m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1213\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"multiclass\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1214\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1215\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1216\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/metrics.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, y_pred, y_true)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mlogs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprefix\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTorchMetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_classes\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_classes\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[operator]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1128\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1129\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1130\u001b[0;31m raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 1131\u001b[0m type(self).__name__, name))\n\u001b[1;32m 1132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'AUC' object has no attribute 'num_classes'"
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 6/6 [00:00<00:00, 79.20it/s, loss=0.1, metrics={'AUC': 8.5, 'AUROC': 0.427}]\n",
"valid: 100%|██████████| 1/1 [00:00<00:00, 6.06it/s, loss=0.0961, metrics={'AUC': 6.5, 'AUROC': 0.419}]\n",
"epoch 2: 100%|██████████| 6/6 [00:00<00:00, 82.52it/s, loss=0.095, metrics={'AUC': 4.5, 'AUROC': 0.4418}]\n",
"valid: 100%|██████████| 1/1 [00:00<00:00, 5.69it/s, loss=0.0917, metrics={'AUC': 6.5, 'AUROC': 0.4351}]\n",
"epoch 3: 100%|██████████| 6/6 [00:00<00:00, 103.30it/s, loss=0.0908, metrics={'AUC': 5.5, 'AUROC': 0.4715}]\n",
"valid: 100%|██████████| 1/1 [00:00<00:00, 5.35it/s, loss=0.0875, metrics={'AUC': 6.5, 'AUROC': 0.4633}]\n",
"epoch 4: 100%|██████████| 6/6 [00:00<00:00, 90.88it/s, loss=0.0872, metrics={'AUC': 7.0, 'AUROC': 0.4767}]\n",
"valid: 100%|██████████| 1/1 [00:00<00:00, 5.44it/s, loss=0.0874, metrics={'AUC': 6.5, 'AUROC': 0.4652}]\n",
"epoch 5: 100%|██████████| 6/6 [00:00<00:00, 88.87it/s, loss=0.0866, metrics={'AUC': 6.0, 'AUROC': 0.4775}]\n",
"valid: 100%|██████████| 1/1 [00:00<00:00, 5.37it/s, loss=0.087, metrics={'AUC': 6.5, 'AUROC': 0.4524}]\n"
]
}
],
......@@ -414,11 +405,10 @@
" lr_schedulers={\"deeptabular\": deep_sch},\n",
" initializers={\"deeptabular\": XavierNormal},\n",
" optimizers={\"deeptabular\": deep_opt},\n",
" metrics=[auc],\n",
" verbose=0,\n",
" metrics=[auc, auroc],\n",
")\n",
"\n",
"trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=10)"
"trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=50)"
]
}
],
......
import numpy as np
import torch
from torchmetrics import Metric as TorchMetric
from torchmetrics import AUC
from .wdtypes import * # noqa: F403
......@@ -38,10 +39,23 @@ class MultipleMetrics(object):
            if isinstance(metric, Metric):
                logs[self.prefix + metric._name] = metric(y_pred, y_true)
            if isinstance(metric, TorchMetric):
                if not hasattr(metric, "num_classes"):
                    raise ValueError(
                        """The torchmetrics metric does not have a num_classes attribute.
                        Use a metric from this library, or extend the torchmetrics metric
                        with a num_classes attribute, see the
                        `examples <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`
                        """
                    )
                if metric.num_classes == 2:
                    metric.update(torch.round(y_pred).int(), y_true.int())
                    if isinstance(metric, AUC):
                        metric.update(torch.round(y_pred).int(), y_true.int())
                    else:
                        metric.update(y_pred, y_true.int())
                if metric.num_classes > 2:  # type: ignore[operator]
                    metric.update(torch.max(y_pred, dim=1).indices, y_true.int())  # type: ignore[attr-defined]
                    if isinstance(metric, AUC):
                        metric.update(torch.max(y_pred, dim=1).indices, y_true.int())  # type: ignore[attr-defined]
                    else:
                        metric.update(y_pred, y_true.int())  # type: ignore[attr-defined]
                logs[self.prefix + type(metric).__name__] = (
                    metric.compute().detach().cpu().numpy()
                )
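
A short sketch of how the dispatch above behaves, assuming a 3-class problem and the torchmetrics behaviour this commit relies on (AUC is fed predicted class indices, other torchmetrics objects are fed the softmax probabilities); the tensors are made up purely for illustration:

import torch
from torchmetrics import AUC, AUROC

from pytorch_widedeep.metrics import MultipleMetrics

y_pred = torch.softmax(torch.randn(8, 3), dim=1)  # per-class probabilities, shape (batch, 3)
y_true = torch.randint(0, 3, (8,))                # integer class labels

auc = AUC(reorder=True)
auc.num_classes = 3          # without this attribute MultipleMetrics raises the ValueError above
auroc = AUROC(num_classes=3)

metrics = MultipleMetrics(metrics=[auc, auroc])
logs = metrics(y_pred, y_true)  # e.g. {'AUC': ..., 'AUROC': ...}
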
......@@ -396,3 +410,62 @@ class R2Score(Metric):
        y_true_avg = self.y_true_sum / self.num_examples
        self.denominator += ((y_true - y_true_avg) ** 2).sum().item()
        return np.array((1 - (self.numerator / self.denominator)))


class Accuracy(Metric):
    r"""Class to calculate the accuracy for both binary and categorical problems

    Parameters
    ----------
    top_k: int, default = 1
        Accuracy will be computed using the top k most likely classes in
        multiclass problems

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.metrics import Accuracy
    >>>
    >>> acc = Accuracy()
    >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
    >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
    >>> acc(y_pred, y_true)
    array(0.5)
    >>>
    >>> acc = Accuracy(top_k=2)
    >>> y_true = torch.tensor([0, 1, 2])
    >>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
    >>> acc(y_pred, y_true)
    array(0.66666667)
    """

    def __init__(self, top_k: int = 1):
        super(Accuracy, self).__init__()

        self.top_k = top_k
        self.correct_count = 0
        self.total_count = 0
        self._name = "acc"

    def reset(self):
        """
        resets counters to 0
        """
        self.correct_count = 0
        self.total_count = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        num_classes = y_pred.size(1)

        if num_classes == 1:
            y_pred = y_pred.round()
            y_true = y_true
        elif num_classes > 1:
            y_pred = y_pred.topk(self.top_k, 1)[1]
            y_true = y_true.view(-1, 1).expand_as(y_pred)

        self.correct_count += y_pred.eq(y_true).sum().item()  # type: ignore[assignment]
        self.total_count += len(y_pred)
        accuracy = float(self.correct_count) / float(self.total_count)
        return np.array(accuracy)
......@@ -147,10 +147,14 @@ class Trainer:
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo
- List of objects of type :obj:`torchmetrics.Metric`. This can be any
metric from torchmetrics library `Examples
metric from the torchmetrics library that has a num_classes attribute `Examples
<https://torchmetrics.readthedocs.io/en/latest/references/modules.html#
classification-metrics>`_. This can also be a custom metric as
long as it is an object of type :obj:`Metric`. See `the instructions
classification-metrics>`_.
Objects of type :obj:`torchmetrics.Metric` can be extended with a num_classes
attribute so that they can be used with the Trainer object; see the `examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__.
This can also be a custom metric as long as it is an object of
type :obj:`Metric`. See `the instructions
<https://torchmetrics.readthedocs.io/en/latest/>`_.
class_weight: float, List or Tuple. optional. default=None
- float indicating the weight of the minority class in binary classification
......
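
As the amended docstring above notes, any torchmetrics metric can be passed to the Trainer as long as it exposes num_classes; metrics that take it in their constructor need no extra step, and the library's own Metric objects (such as the Accuracy class added above) can be mixed in freely. A hedged sketch, assuming a 3-class WideDeep model built as in the notebook:

from torchmetrics import F1

from pytorch_widedeep import Trainer
from pytorch_widedeep.metrics import Accuracy

trainer = Trainer(
    model,  # the WideDeep model, assumed to be built as in the notebook diff above
    objective="multiclass",
    metrics=[Accuracy(top_k=2), F1(num_classes=3, average="macro")],
)
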
import numpy as np
import torch
import pytest
from torchmetrics import F1, FBeta, Recall, Accuracy, Precision
from torchmetrics import F1, FBeta, Recall, Accuracy, Precision, AUC
from sklearn.metrics import (
    f1_score,
    fbeta_score,
    recall_score,
    accuracy_score,
    precision_score,
    roc_auc_score,
)
from pytorch_widedeep.metrics import MultipleMetrics
......@@ -35,9 +36,12 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
("Recall", recall_score, Recall(num_classes=2, average="none")),
("F1", f1_score, F1(num_classes=2, average="none")),
("FBeta", f2_score_bin, FBeta(beta=2, num_classes=2, average="none")),
("AUC", auc_score, AUC()),
],
)
def test_binary_metrics(metric_name, sklearn_metric, torch_metric):
    if metric_name == "AUC":
        torch_metric.num_classes = 2
    sk_res = sklearn_metric(y_true_bin_np, y_pred_bin_np.round())
    wd_metric = MultipleMetrics(metrics=[torch_metric])
    wd_logs = wd_metric(y_pred_bin_pt, y_true_bin_pt)
......@@ -82,11 +86,14 @@ def f2_score_multi(y_true, y_pred, average):
("Recall", recall_score, Recall(num_classes=3, average="macro")),
("F1", f1_score, F1(num_classes=3, average="macro")),
("FBeta", f2_score_multi, FBeta(beta=3, num_classes=3, average="macro")),
("AUC", auc_score, AUC()),
],
)
def test_muticlass_metrics(metric_name, sklearn_metric, torch_metric):
    if metric_name == "Accuracy":
        sk_res = sklearn_metric(y_true_multi_np, y_pred_muli_np.argmax(axis=1))
    elif metric_name == "AUC":
        torch_metric.num_classes = 3
    else:
        sk_res = sklearn_metric(
            y_true_multi_np, y_pred_muli_np.argmax(axis=1), average="macro"
......