Unverified · Commit 65465a45 authored by Javier, committed by GitHub

Merge pull request #17 from jrzaurin/precision_recall

Precision recall
...@@ -40,7 +40,9 @@ final output neuron or neurons, depending on whether we are performing a
binary classification or regression, or a multi-class classification. The
components within the faded-pink rectangles are concatenated.
In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated
as:
<p align="center">
<img width="500" src="docs/figures/architecture_1_math.png">
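For readers viewing this diff without the image, `architecture_1_math.png` shows the prediction equation of the Wide & Deep paper; a close transcription (treat the exact symbols as an approximation of the figure):

```latex
P(Y=1 \mid \mathbf{x}) = \sigma\left(\mathbf{w}_{wide}^{T}\,[\mathbf{x}, \phi(\mathbf{x})] + \mathbf{w}_{deep}^{T}\,a^{(l_f)} + b\right)
```

Here \(\sigma\) is the sigmoid, \(\phi(\mathbf{x})\) the cross-product transformations of the wide inputs, \(a^{(l_f)}\) the activations of the last layer of the deep component, and \(b\) the bias.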
...@@ -130,7 +132,7 @@ from sklearn.model_selection import train_test_split
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import Accuracy

# these next 4 lines are not directly related to pytorch-widedeep. I assume
# you have downloaded the dataset and place it in a dir called data/adult/
...@@ -178,7 +180,7 @@ deepdense = DeepDense(
# build, compile and fit
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary", metrics=[Accuracy])
model.fit(
    X_wide=X_wide,
    X_deep=X_deep,
......
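After `fit`, predictions follow the same keyword convention; a minimal sketch, assuming a held-out split where `X_wide_te` and `X_deep_te` are hypothetical test arrays produced with the same preprocessors:

```python
# X_wide_te / X_deep_te are hypothetical held-out arrays, built with the
# same WidePreprocessor / DensePreprocessor fitted above
preds = model.predict(X_wide=X_wide_te, X_deep=X_deep_te)
```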
0.4.2
\ No newline at end of file
# sort imports
isort . pytorch_widedeep tests examples setup.py
# Black code style
black . pytorch_widedeep tests examples setup.py
# flake8 standards
......
pytorch-widedeep Examples
*****************************
This section provides links to example notebooks that may be helpful to better
understand the functionalities within ``pytorch-widedeep`` and how to use
them to address different problems.
* `Preprocessors and Utils <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/01_Preprocessors_and_utils.ipynb>`__
* `Model Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/02_Model_Components.ipynb>`__
* `Binary Classification with default parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/03_Binary_Classification_with_Defaults.ipynb>`__
* `Binary Classification with varying parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/04_Binary_Classification_Varying_Parameters.ipynb>`__
* `Regression with Images and Text <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/05_Regression_with_Images_and_Text.ipynb>`__
* `Warm up routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_WarmUp_Model_Components.ipynb>`__
...@@ -19,6 +19,7 @@ Documentation
Preprocessing <preprocessing>
Model Components <model_components>
Wide and Deep Models <wide_deep/index>
Examples <examples>

Introduction
......
...@@ -30,7 +30,7 @@ Prepare the wide and deep columns
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import Accuracy

# prepare wide, crossed, embedding and continuous columns
wide_cols = [
...@@ -83,7 +83,7 @@ Build, compile, fit and predict
# build, compile and fit
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary", metrics=[Accuracy])
model.fit(
    X_wide=X_wide,
    X_deep=X_deep,
......
Metrics
=======

.. autoclass:: pytorch_widedeep.metrics.Accuracy
   :members:
   :undoc-members:
   :show-inheritance:

.. autoclass:: pytorch_widedeep.metrics.Precision
   :members:
   :undoc-members:
   :show-inheritance:

.. autoclass:: pytorch_widedeep.metrics.Recall
   :members:
   :undoc-members:
   :show-inheritance:

.. autoclass:: pytorch_widedeep.metrics.FBetaScore
   :members:
   :undoc-members:
   :show-inheritance:

.. autoclass:: pytorch_widedeep.metrics.F1Score
   :members:
   :undoc-members:
   :show-inheritance:
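The new metric classes can also be exercised standalone, which is how the unit tests added in this PR validate them; a small sketch using the binary fixture from those tests:

```python
import numpy as np
import torch

from pytorch_widedeep.metrics import Precision, Recall, FBetaScore

y_true = torch.from_numpy(np.array([1, 0, 0, 0, 1, 1, 0]))
# a single column of probabilities, as in the binary case
y_pred = torch.from_numpy(np.array([0.6, 0.3, 0.2, 0.8, 0.4, 0.9, 0.6])).view(-1, 1)

print(Precision()(y_pred, y_true))         # 2 TP / 4 predicted positives = 0.5
print(Recall()(y_pred, y_true))            # 2 TP / 3 actual positives ~ 0.6667
print(FBetaScore(beta=2)(y_pred, y_true))  # 0.625
```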
...@@ -170,11 +170,11 @@
{
"data": {
"text/plain": [
"tensor([[-0.0000, -0.9949, 3.8273, 0.0000, -1.3889, -2.9641, 0.0000, -0.0000],\n",
" [ 3.9123, -0.0000, -0.0000, 1.9555, -1.3561, 1.7069, -0.0000, 0.9275],\n",
" [-0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -1.6489, -0.0000, -1.4985],\n",
" [-1.2736, 0.0000, -1.2819, 2.1232, 0.0000, 2.2767, -0.0000, 3.5354],\n",
" [-0.1726, -0.0000, -1.3275, -0.0000, -1.3703, 0.0000, -0.0000, -1.4637]],\n",
" grad_fn=<MulBackward0>)"
]
},
...@@ -484,10 +484,10 @@
{
"data": {
"text/plain": [
"tensor([[-2.2825e-03, -8.3100e-04, -8.8423e-04, -1.1084e-04, 8.8529e-02,\n",
" -5.1577e-04, 2.8343e-01, -1.7071e-03],\n",
" [-1.8486e-03, -8.5602e-04, -1.8552e-03, 3.6481e-01, 9.0812e-02,\n",
" -9.6603e-04, 3.9017e-01, -2.6355e-03]], grad_fn=<LeakyReluBackward1>)"
]
},
"execution_count": 18,
......
...@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
...@@ -21,12 +21,12 @@
"\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
"from pytorch_widedeep.metrics import Accuracy, Precision"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
...@@ -185,7 +185,7 @@
"4 30 United-States <=50K "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
...@@ -197,7 +197,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
...@@ -356,7 +356,7 @@
"4 30 United-States 0 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
...@@ -381,7 +381,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
...@@ -394,7 +394,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
...@@ -412,7 +412,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
...@@ -437,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
...@@ -475,7 +475,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
...@@ -489,7 +489,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
...@@ -527,7 +527,7 @@
")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
...@@ -560,16 +560,16 @@
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"model.compile(method='binary', metrics=[Accuracy, Precision])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
...@@ -591,16 +591,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 102.41it/s, loss=0.585, metrics={'acc': 0.7512, 'prec': 0.1818}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 98.78it/s, loss=0.513, metrics={'acc': 0.754, 'prec': 0.2429}]\n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 117.30it/s, loss=0.481, metrics={'acc': 0.782, 'prec': 0.8287}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 106.49it/s, loss=0.454, metrics={'acc': 0.7866, 'prec': 0.8245}]\n",
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 124.78it/s, loss=0.44, metrics={'acc': 0.8055, 'prec': 0.781}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 115.36it/s, loss=0.425, metrics={'acc': 0.8077, 'prec': 0.7818}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 125.01it/s, loss=0.418, metrics={'acc': 0.814, 'prec': 0.7661}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 114.92it/s, loss=0.408, metrics={'acc': 0.8149, 'prec': 0.7671}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 116.57it/s, loss=0.404, metrics={'acc': 0.819, 'prec': 0.7527}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 108.89it/s, loss=0.397, metrics={'acc': 0.8203, 'prec': 0.7547}]\n"
]
}
],
......
...@@ -43,7 +43,7 @@
"\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
"from pytorch_widedeep.metrics import Accuracy"
]
},
{
...@@ -273,7 +273,7 @@
"metadata": {},
"outputs": [],
"source": [
"model.compile(method='binary', metrics=[Accuracy])"
]
},
{
...@@ -307,11 +307,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 127.54it/s, loss=0.476, metrics={'acc': 0.7808972948071559}]\n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 126.88it/s, loss=0.373, metrics={'acc': 0.8048268625393494}]\n",
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 141.92it/s, loss=0.365, metrics={'acc': 0.8136820822562895}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 151.56it/s, loss=0.362, metrics={'acc': 0.8182312594374632}]\n",
"epoch 5: 100%|██████████| 153/153 [00:00<00:00, 158.22it/s, loss=0.36, metrics={'acc': 0.8210477823561027}]\n",
" 0%| | 0/153 [00:00<?, ?it/s]"
]
},
...@@ -326,11 +326,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 78.65it/s, loss=0.397, metrics={'acc': 0.8198073691125158}]\n",
"epoch 2: 100%|██████████| 153/153 [00:02<00:00, 75.69it/s, loss=0.348, metrics={'acc': 0.8221936229255862}]\n",
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 74.79it/s, loss=0.343, metrics={'acc': 0.8243576126737133}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 76.79it/s, loss=0.338, metrics={'acc': 0.8264502057402526}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 79.57it/s, loss=0.334, metrics={'acc': 0.8283059913495252}]\n",
" 0%| | 0/153 [00:00<?, ?it/s]"
]
},
...@@ -345,16 +345,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 114.10it/s, loss=0.36, metrics={'acc': 0.8323}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 123.16it/s, loss=0.364, metrics={'acc': 0.8325}]\n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 113.50it/s, loss=0.359, metrics={'acc': 0.8325}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 122.56it/s, loss=0.364, metrics={'acc': 0.8327}]\n",
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 110.90it/s, loss=0.359, metrics={'acc': 0.8325}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 119.56it/s, loss=0.363, metrics={'acc': 0.8327}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 112.92it/s, loss=0.359, metrics={'acc': 0.8326}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.00it/s, loss=0.363, metrics={'acc': 0.8329}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 114.15it/s, loss=0.358, metrics={'acc': 0.8327}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 108.91it/s, loss=0.363, metrics={'acc': 0.8329}]\n"
]
}
],
...@@ -450,7 +450,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 8%|▊ | 84/1001 [00:00<00:02, 416.73it/s]"
]
},
{
...@@ -464,7 +464,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1001/1001 [00:02<00:00, 400.82it/s]\n"
]
},
{
...@@ -848,7 +848,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [00:00<00:00, 58.03it/s, loss=127]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
...@@ -863,7 +863,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [00:00<00:00, 47.80it/s, loss=116]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
...@@ -878,7 +878,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [00:04<00:00, 5.94it/s, loss=132]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
...@@ -893,7 +893,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [01:12<00:00, 2.92s/it, loss=119]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
...@@ -908,7 +908,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [01:48<00:00, 4.34s/it, loss=108]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
...@@ -923,7 +923,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [02:05<00:00, 5.01s/it, loss=106]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
...@@ -938,7 +938,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [02:57<00:00, 7.11s/it, loss=105]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
...@@ -953,7 +953,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [03:40<00:00, 8.83s/it, loss=104]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
...@@ -968,8 +968,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [01:20<00:00, 3.23s/it, loss=120]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.06s/it, loss=109]\n"
]
}
],
......
...@@ -4,7 +4,7 @@ import pandas as pd
from pytorch_widedeep.optim import RAdam
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.metrics import Accuracy, Precision
from pytorch_widedeep.callbacks import (
    LRHistory,
    EarlyStopping,
...@@ -76,7 +76,7 @@ if __name__ == "__main__":
        EarlyStopping,
        ModelCheckpoint(filepath="model_weights/wd_out"),
    ]
    metrics = [Accuracy, Precision]
    model.compile(
        method="binary",
......
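Putting the pieces of this script together, the truncated `compile` call above plausibly continues as follows (a sketch; only `method`, `callbacks` and `metrics` are confirmed by this diff, the script's other arguments are elided):

```python
model.compile(
    method="binary",
    callbacks=callbacks,  # LRHistory, EarlyStopping, ModelCheckpoint, ...
    metrics=metrics,      # [Accuracy, Precision]
)
```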
...@@ -8,9 +8,8 @@ from collections import Counter
import numpy as np
import pandas as pd
import gender_guesser.detector as gender
from sklearn.preprocessing import MultiLabelBinarizer

warnings.filterwarnings("ignore")
......
...@@ -3,7 +3,7 @@ import torch
import pandas as pd
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.metrics import F1Score, Accuracy
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor

use_cuda = torch.cuda.is_available()
...@@ -48,7 +48,7 @@ if __name__ == "__main__":
        continuous_cols=continuous_cols,
    )
    model = WideDeep(wide=wide, deepdense=deepdense, pred_dim=3)
    model.compile(method="multiclass", metrics=[Accuracy, F1Score])
    model.fit(
        X_wide=X_wide,
......
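To see what `metrics=[Accuracy, F1Score]` computes on the multiclass path (model scores pass through `F.softmax` before reaching the metric), a standalone sketch with toy probabilities:

```python
import numpy as np
import torch

from pytorch_widedeep.metrics import Accuracy, F1Score

y_true = torch.from_numpy(np.array([0, 1, 2, 1]))
y_prob = torch.from_numpy(
    np.array(
        [
            [0.7, 0.2, 0.1],  # predicted 0, correct
            [0.1, 0.8, 0.1],  # predicted 1, correct
            [0.2, 0.2, 0.6],  # predicted 2, correct
            [0.2, 0.3, 0.5],  # predicted 2, true label 1
        ]
    )
)

print(Accuracy()(y_prob, y_true))  # 3/4 = 0.75
print(F1Score()(y_prob, y_true))   # unweighted mean over the 3 classes ~ 0.7778
```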
[![Build Status](https://travis-ci.org/jrzaurin/pytorch-widedeep.svg?branch=master)](https://travis-ci.org/jrzaurin/pytorch-widedeep)
[![Documentation Status](https://readthedocs.org/projects/pytorch-widedeep/badge/?version=latest)](https://pytorch-widedeep.readthedocs.io/en/latest/?badge=latest)
# pytorch-widedeep

A flexible package to combine tabular data with text and images using wide and
...@@ -64,7 +68,7 @@ from sklearn.model_selection import train_test_split
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import Accuracy

# these next 4 lines are not directly related to pytorch-widedeep. I assume
# you have downloaded the dataset and place it in a dir called data/adult/
...@@ -112,7 +116,7 @@ deepdense = DeepDense(
# build, compile and fit
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary", metrics=[Accuracy])
model.fit(
    X_wide=X_wide,
    X_deep=X_deep,
......
import numpy as np
import torch

from .wdtypes import *
from .callbacks import Callback
...@@ -46,23 +47,17 @@ class MetricCallback(Callback):
        self.container.reset()
class Accuracy(Metric):
    r"""Class to calculate the accuracy for both binary and categorical problems

    Parameters
    ----------
    top_k: int, default = 1
        Accuracy will be computed using the top k most likely classes in
        multiclass problems
    """

    def __init__(self, top_k: int = 1):
        self.top_k = top_k
        self.correct_count = 0
        self.total_count = 0
...@@ -77,41 +72,179 @@ class CategoricalAccuracy(Metric):
        self.total_count = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        num_classes = y_pred.size(1)

        if num_classes == 1:
            y_pred = y_pred.round()
            y_true = y_true.view(-1, 1)
        elif num_classes > 1:
            y_pred = y_pred.topk(self.top_k, 1)[1]
            y_true = y_true.view(-1, 1).expand_as(y_pred)  # type: ignore

        self.correct_count += y_pred.eq(y_true).sum().item()  # type: ignore
        self.total_count += len(y_pred)  # type: ignore
        accuracy = float(self.correct_count) / float(self.total_count)
        return accuracy
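A quick sanity check of the binary branch above (a single column of probabilities is rounded to 0/1 before comparison):

```python
import torch

from pytorch_widedeep.metrics import Accuracy

y_pred = torch.tensor([[0.9], [0.2], [0.6], [0.4]])  # rounds to 1, 0, 1, 0
y_true = torch.tensor([1.0, 0.0, 0.0, 1.0])

acc = Accuracy()
print(acc(y_pred, y_true))  # 2 correct out of 4 -> 0.5
```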
class Precision(Metric):
    r"""Class to calculate the precision for both binary and categorical problems

    Parameters
    ----------
    average: bool, default = True
        This applies only to multiclass problems. If ``True``, calculate
        precision for each label and return their unweighted mean.
    """

    def __init__(self, average: bool = True):
        self.average = average
        self.true_positives = 0
        self.all_positives = 0
        self.eps = 1e-20
        self._name = "prec"

    def reset(self) -> None:
        """
        resets counters to 0
        """
        self.true_positives = 0
        self.all_positives = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        num_class = y_pred.size(1)

        if num_class == 1:
            y_pred = y_pred.round()
            y_true = y_true.view(-1, 1)
        elif num_class > 1:
            y_true = torch.eye(num_class)[y_true.long()]
            y_pred = y_pred.topk(1, 1)[1].view(-1)
            y_pred = torch.eye(num_class)[y_pred.long()]

        self.true_positives += (y_true * y_pred).sum(dim=0)  # type:ignore
        self.all_positives += y_pred.sum(dim=0)  # type:ignore

        precision = self.true_positives / (self.all_positives + self.eps)

        if self.average:
            return precision.mean().item()  # type:ignore
        else:
            return precision
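The multiclass branch works by one-hot encoding both tensors so that per-class counts reduce to column sums; the same computation spelled out in plain torch (variable names here are illustrative only):

```python
import torch

y_true = torch.tensor([0, 1, 1, 2])
y_pred_cls = torch.tensor([0, 1, 2, 2])  # argmax of the model scores

num_class = 3
t = torch.eye(num_class)[y_true]       # one-hot targets, shape (4, 3)
p = torch.eye(num_class)[y_pred_cls]   # one-hot predictions

true_positives = (t * p).sum(dim=0)    # per class: [1., 1., 1.]
all_positives = p.sum(dim=0)           # per class: [1., 1., 2.]
print(true_positives / all_positives)  # per-class precision: [1.0, 1.0, 0.5]
```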
class Recall(Metric):
    r"""Class to calculate the recall for both binary and categorical problems

    Parameters
    ----------
    average: bool, default = True
        This applies only to multiclass problems. If ``True``, calculate
        recall for each label and return their unweighted mean.
    """

    def __init__(self, average: bool = True):
        self.average = average
        self.true_positives = 0
        self.actual_positives = 0
        self.eps = 1e-20
        self._name = "rec"

    def reset(self) -> None:
        """
        resets counters to 0
        """
        self.true_positives = 0
        self.actual_positives = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        num_class = y_pred.size(1)

        if num_class == 1:
            y_pred = y_pred.round()
            y_true = y_true.view(-1, 1)
        elif num_class > 1:
            y_true = torch.eye(num_class)[y_true.long()]
            y_pred = y_pred.topk(1, 1)[1].view(-1)
            y_pred = torch.eye(num_class)[y_pred.long()]

        self.true_positives += (y_true * y_pred).sum(dim=0)  # type: ignore
        self.actual_positives += y_true.sum(dim=0)  # type: ignore

        recall = self.true_positives / (self.actual_positives + self.eps)

        if self.average:
            return recall.mean().item()  # type:ignore
        else:
            return recall
class FBetaScore(Metric):
    r"""Class to calculate the fbeta score for both binary and categorical problems

    ``FBeta = ((1 + beta^2) * Precision * Recall) / (beta^2 * Precision + Recall)``

    Parameters
    ----------
    beta: int
        Coefficient to control the balance between precision and recall
    average: bool, default = True
        This applies only to multiclass problems. If ``True``, calculate
        fbeta for each label and return their unweighted mean.
    """

    def __init__(self, beta: int, average: bool = True):
        self.average = average
        self.precision = Precision(average=False)
        self.recall = Recall(average=False)
        self.beta = beta
        self._name = "".join(["f", str(beta)])

    def reset(self) -> None:
        """
        resets precision and recall
        """
        self.precision.reset()
        self.recall.reset()

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        prec = self.precision(y_pred, y_true)
        rec = self.recall(y_pred, y_true)
        beta2 = self.beta ** 2

        fbeta = ((1 + beta2) * prec * rec) / (beta2 * prec + rec)

        if self.average:
            return fbeta.mean().item()
        else:
            return fbeta
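A worked instance of the formula in the docstring, using the precision (0.5) and recall (2/3) from the binary example earlier:

```python
# beta = 2 weighs recall more heavily than precision
beta2 = 2 ** 2
prec, rec = 0.5, 2 / 3
fbeta = ((1 + beta2) * prec * rec) / (beta2 * prec + rec)
print(fbeta)  # 0.625
```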
class F1Score(Metric):
    r"""Class to calculate the f1 score for both binary and categorical problems

    Parameters
    ----------
    average: bool, default = True
        This applies only to multiclass problems. If ``True``, calculate
        f1 for each label and return their unweighted mean.
    """

    def __init__(self, average: bool = True):
        self.f1 = FBetaScore(beta=1, average=average)
        self._name = self.f1._name

    def reset(self) -> None:
        """
        resets counters to 0
        """
        self.f1.reset()

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        return self.f1(y_pred, y_true)
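And a cross-check against scikit-learn, mirroring the unit tests added further down in this PR:

```python
import numpy as np
import torch
from sklearn.metrics import f1_score

from pytorch_widedeep.metrics import F1Score

y_true = np.array([1, 0, 0, 0, 1, 1, 0])
y_prob = np.array([0.6, 0.3, 0.2, 0.8, 0.4, 0.9, 0.6])

wd_f1 = F1Score()(torch.from_numpy(y_prob).view(-1, 1), torch.from_numpy(y_true))
assert np.isclose(f1_score(y_true, y_prob.round()), wd_f1)  # both ~ 0.5714
```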
...@@ -266,9 +266,9 @@ class WideDeep(nn.Module):
        See the ``Callbacks`` section in this documentation or
        :obj:`pytorch_widedeep.callbacks`
    metrics: List[Metric], Optional. Default=None
        Metrics available are: ``Accuracy``, ``Precision``, ``Recall``,
        ``FBetaScore`` and ``F1Score``. See the ``Metrics`` section in
        this documentation or :obj:`pytorch_widedeep.metrics`
    class_weight: Union[float, List[float], Tuple[float]]. Optional. Default=None
        - float indicating the weight of the minority class in binary classification
          problems (e.g. 9.)
...@@ -587,17 +587,22 @@ class WideDeep(nn.Module):
        with trange(train_steps, disable=self.verbose != 1) as t:
            for batch_idx, (data, target) in zip(t, train_loader):
                t.set_description("epoch %i" % (epoch + 1))
                score, train_loss = self._training_step(data, target, batch_idx)
                if score is not None:
                    t.set_postfix(
                        metrics={k: np.round(v, 4) for k, v in score.items()},
                        loss=train_loss,
                    )
                else:
                    t.set_postfix(loss=np.sqrt(train_loss))
                if self.lr_scheduler:
                    self._lr_scheduler_step(step_location="on_batch_end")
                self.callback_container.on_batch_end(batch=batch_idx)
        epoch_logs["train_loss"] = train_loss
        if score is not None:
            for k, v in score.items():
                log_k = "_".join(["train", k])
                epoch_logs[log_k] = v

        # eval step...
        if epoch % validation_freq == (validation_freq - 1):
            if eval_set is not None:
...@@ -612,14 +617,21 @@ class WideDeep(nn.Module):
                with trange(eval_steps, disable=self.verbose != 1) as v:
                    for i, (data, target) in zip(v, eval_loader):
                        v.set_description("valid")
                        score, val_loss = self._validation_step(data, target, i)
                        if score is not None:
                            v.set_postfix(
                                metrics={
                                    k: np.round(v, 4) for k, v in score.items()
                                },
                                loss=val_loss,
                            )
                        else:
                            v.set_postfix(loss=np.sqrt(val_loss))
                epoch_logs["val_loss"] = val_loss
                if score is not None:
                    for k, v in score.items():
                        log_k = "_".join(["val", k])
                        epoch_logs[log_k] = v

        if self.lr_scheduler:
            self._lr_scheduler_step(step_location="on_epoch_end")

        # log and check if early_stop...
...@@ -986,10 +998,10 @@ class WideDeep(nn.Module):
        if self.metric is not None:
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            if self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
            return score, avg_loss
        else:
            return None, avg_loss
...@@ -1008,10 +1020,10 @@ class WideDeep(nn.Module):
        if self.metric is not None:
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            if self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
            return score, avg_loss
        else:
            return None, avg_loss
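The net effect of renaming `acc` to `score` is that `epoch_logs` now carries one entry per compiled metric rather than the hard-coded `train_acc`/`val_acc`; a sketch of the logging above, with values taken from epoch 5 of the notebook run shown earlier:

```python
# what _training_step returns when compile(..., metrics=[Accuracy, Precision])
score = {"acc": 0.819, "prec": 0.7527}
epoch_logs = {"train_loss": 0.404}

for k, v in score.items():
    epoch_logs["_".join(["train", k])] = v

print(epoch_logs)  # {'train_loss': 0.404, 'train_acc': 0.819, 'train_prec': 0.7527}
```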
......
__version__ = "0.4.2"
...@@ -3,23 +3,113 @@ from copy import deepcopy
import numpy as np
import torch
import pytest
from sklearn.metrics import (
    f1_score,
    fbeta_score,
    recall_score,
    accuracy_score,
    precision_score,
)

from pytorch_widedeep.metrics import (
    Recall,
    F1Score,
    Accuracy,
    Precision,
    FBetaScore,
)


def f2_score_bin(y_true, y_pred):
    return fbeta_score(y_true, y_pred, beta=2)


y_true_bin_np = np.array([1, 0, 0, 0, 1, 1, 0])
y_pred_bin_np = np.array([0.6, 0.3, 0.2, 0.8, 0.4, 0.9, 0.6])

y_true_bin_pt = torch.from_numpy(y_true_bin_np)
y_pred_bin_pt = torch.from_numpy(y_pred_bin_np).view(-1, 1)


###############################################################################
# Test binary metrics
###############################################################################
@pytest.mark.parametrize(
    "sklearn_metric, widedeep_metric",
    [
        (accuracy_score, Accuracy()),
        (precision_score, Precision()),
        (recall_score, Recall()),
        (f1_score, F1Score()),
        (f2_score_bin, FBetaScore(beta=2)),
    ],
)
def test_binary_metrics(sklearn_metric, widedeep_metric):
    assert np.isclose(
        sklearn_metric(y_true_bin_np, y_pred_bin_np.round()),
        widedeep_metric(y_pred_bin_pt, y_true_bin_pt),
    )


###############################################################################
# Test top_k for Accuracy
###############################################################################
@pytest.mark.parametrize("top_k, expected_acc", [(1, 0.33), (2, 0.66)])
def test_categorical_accuracy_topk(top_k, expected_acc):
    y_true = torch.from_numpy(np.random.choice(3, 100))
    y_pred = torch.from_numpy(np.random.rand(100, 3))
    metric = Accuracy(top_k=top_k)
    acc = metric(y_pred, y_true)
    assert np.isclose(acc, expected_acc, atol=0.3)


###############################################################################
# Test multiclass metrics
###############################################################################
y_true_multi_np = np.array([1, 0, 2, 1, 1, 2, 2, 0, 0, 0])
y_pred_multi_np = np.array(
    [
        [0.2, 0.6, 0.2],
        [0.4, 0.5, 0.1],
        [0.1, 0.1, 0.8],
        [0.1, 0.6, 0.3],
        [0.1, 0.8, 0.1],
        [0.1, 0.6, 0.6],
        [0.2, 0.6, 0.8],
        [0.6, 0.1, 0.3],
        [0.7, 0.2, 0.1],
        [0.1, 0.7, 0.2],
    ]
)

y_true_multi_pt = torch.from_numpy(y_true_multi_np)
y_pred_multi_pt = torch.from_numpy(y_pred_multi_np)


def f2_score_multi(y_true, y_pred, average):
    return fbeta_score(y_true, y_pred, average=average, beta=2)


@pytest.mark.parametrize(
    "sklearn_metric, widedeep_metric",
    [
        (accuracy_score, Accuracy()),
        (precision_score, Precision()),
        (recall_score, Recall()),
        (f1_score, F1Score()),
        (f2_score_multi, FBetaScore(beta=2)),
    ],
)
def test_multiclass_metrics(sklearn_metric, widedeep_metric):
    if sklearn_metric.__name__ == "accuracy_score":
        assert np.isclose(
            sklearn_metric(y_true_multi_np, y_pred_multi_np.argmax(axis=1)),
            widedeep_metric(y_pred_multi_pt, y_true_multi_pt),
        )
    else:
        assert np.isclose(
            sklearn_metric(
                y_true_multi_np, y_pred_multi_np.argmax(axis=1), average="macro"
            ),
            widedeep_metric(y_pred_multi_pt, y_true_multi_pt),
        )
...@@ -9,7 +9,7 @@ from sklearn.utils import Bunch
from torch.utils.data import Dataset, DataLoader

from pytorch_widedeep.models import Wide, DeepDense
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.models._warmup import WarmUp
from pytorch_widedeep.models.deep_image import conv_layer
...@@ -138,7 +138,7 @@ wdset = WDset(X_wide, X_deep, X_text, X_image, target)
wdloader = DataLoader(wdset, batch_size=10, shuffle=True)

# Instantiate the WarmUp class
warmer = WarmUp(loss_fn, Accuracy(), "binary", False)

# List the layers for the warm_gradual method
text_layers = [c for c in list(deeptext.children())[1:]][::-1]
......