未验证 提交 65465a45 编写于 作者: J Javier 提交者: GitHub

Merge pull request #17 from jrzaurin/precision_recall

Precision recall
......@@ -40,7 +40,9 @@ final output neuron or neurons, depending on whether we are performing a
binary classification or regression, or a multi-class classification. The
components within the faded-pink rectangles are concatenated.
In math terms, and following the notation in the [paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated as:
In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated
as:
<p align="center">
<img width="500" src="docs/figures/architecture_1_math.png">
......@@ -130,7 +132,7 @@ from sklearn.model_selection import train_test_split
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import BinaryAccuracy
from pytorch_widedeep.metrics import Accuracy
# these next 4 lines are not directly related to pytorch-widedeep. I assume
# you have downloaded the dataset and placed it in a dir called data/adult/
......@@ -178,7 +180,7 @@ deepdense = DeepDense(
# build, compile and fit
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary", metrics=[BinaryAccuracy])
model.compile(method="binary", metrics=[Accuracy])
model.fit(
X_wide=X_wide,
X_deep=X_deep,
......
0.4.1
\ No newline at end of file
0.4.2
\ No newline at end of file
# sort imports
isort --recursive . pytorch_widedeep tests examples setup.py
isort . pytorch_widedeep tests examples setup.py
# Black code style
black . pytorch_widedeep tests examples setup.py
# flake8 standards
......
pytorch-widedeep Examples
*****************************
This section provides links to example notebooks that may be helpful to better
understand the functionalities within ``pytorch-widedeep`` and how to use
them to address different problems.
* `Preprocessors and Utils <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/01_Preprocessors_and_utils.ipynb>`__
* `Model Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/02_Model_Components.ipynb>`__
* `Binary Classification with default parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/03_Binary_Classification_with_Defaults.ipynb>`__
* `Binary Classification with varying parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/04_Binary_Classification_Varying_Parameters.ipynb>`__
* `Regression with Images and Text <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/05_Regression_with_Images_and_Text.ipynb>`__
* `Warm up routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_WarmUp_Model_Components.ipynb>`__
......@@ -19,6 +19,7 @@ Documentation
Preprocessing <preprocessing>
Model Components <model_components>
Wide and Deep Models <wide_deep/index>
Examples <examples>
Introduction
......
......@@ -30,7 +30,7 @@ Prepare the wide and deep columns
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import BinaryAccuracy
from pytorch_widedeep.metrics import Accuracy
# prepare wide, crossed, embedding and continuous columns
wide_cols = [
......@@ -83,7 +83,7 @@ Build, compile, fit and predict
# build, compile and fit
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary", metrics=[BinaryAccuracy])
model.compile(method="binary", metrics=[Accuracy])
model.fit(
X_wide=X_wide,
X_deep=X_deep,
......
Metrics
=======
.. autoclass:: pytorch_widedeep.metrics.BinaryAccuracy
.. autoclass:: pytorch_widedeep.metrics.Accuracy
:members:
:undoc-members:
:show-inheritance:
.. autoclass:: pytorch_widedeep.metrics.CategoricalAccuracy
.. autoclass:: pytorch_widedeep.metrics.Precision
:members:
:undoc-members:
:show-inheritance:
.. autoclass:: pytorch_widedeep.metrics.Recall
:members:
:undoc-members:
:show-inheritance:
.. autoclass:: pytorch_widedeep.metrics.FBetaScore
:members:
:undoc-members:
:show-inheritance:
.. autoclass:: pytorch_widedeep.metrics.F1Score
:members:
:undoc-members:
:show-inheritance:
......@@ -170,11 +170,11 @@
{
"data": {
"text/plain": [
"tensor([[-0.0000, -1.0061, -0.0000, -0.9828, -0.0000, -0.0000, -0.9944, -1.0133],\n",
" [-0.0000, -0.9996, 0.0000, -1.0374, 0.0000, -0.0000, -1.0313, -0.0000],\n",
" [-0.8576, -1.0017, -0.0000, -0.9881, -0.0000, 0.0000, -0.0000, -0.0000],\n",
" [ 3.9816, 0.0000, 0.0000, 0.0000, 3.7309, 1.1728, 0.0000, -1.1160],\n",
" [-1.1339, -0.9925, -0.0000, -0.0000, -0.0000, 0.0000, -0.9638, 0.0000]],\n",
"tensor([[-0.0000, -0.9949, 3.8273, 0.0000, -1.3889, -2.9641, 0.0000, -0.0000],\n",
" [ 3.9123, -0.0000, -0.0000, 1.9555, -1.3561, 1.7069, -0.0000, 0.9275],\n",
" [-0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -1.6489, -0.0000, -1.4985],\n",
" [-1.2736, 0.0000, -1.2819, 2.1232, 0.0000, 2.2767, -0.0000, 3.5354],\n",
" [-0.1726, -0.0000, -1.3275, -0.0000, -1.3703, 0.0000, -0.0000, -1.4637]],\n",
" grad_fn=<MulBackward0>)"
]
},
......@@ -484,10 +484,10 @@
{
"data": {
"text/plain": [
"tensor([[-1.4630e-04, -6.1540e-04, -2.4541e-04, 2.7543e-01, 1.2993e-01,\n",
" -1.6553e-03, 6.7002e-02, 2.3974e-01],\n",
" [-9.9619e-04, -1.9412e-03, 1.2113e-01, 1.0122e-01, 2.9080e-01,\n",
" -2.0852e-03, -1.8016e-04, 2.7996e-02]], grad_fn=<LeakyReluBackward1>)"
"tensor([[-2.2825e-03, -8.3100e-04, -8.8423e-04, -1.1084e-04, 8.8529e-02,\n",
" -5.1577e-04, 2.8343e-01, -1.7071e-03],\n",
" [-1.8486e-03, -8.5602e-04, -1.8552e-03, 3.6481e-01, 9.0812e-02,\n",
" -9.6603e-04, 3.9017e-01, -2.6355e-03]], grad_fn=<LeakyReluBackward1>)"
]
},
"execution_count": 18,
......
......@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
......@@ -21,12 +21,12 @@
"\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
"from pytorch_widedeep.metrics import BinaryAccuracy"
"from pytorch_widedeep.metrics import Accuracy, Precision"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"outputs": [
{
......@@ -185,7 +185,7 @@
"4 30 United-States <=50K "
]
},
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
......@@ -197,7 +197,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 3,
"metadata": {},
"outputs": [
{
......@@ -356,7 +356,7 @@
"4 30 United-States 0 "
]
},
"execution_count": 7,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
......@@ -381,7 +381,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
......@@ -394,7 +394,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
......@@ -412,7 +412,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 6,
"metadata": {},
"outputs": [
{
......@@ -437,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 7,
"metadata": {},
"outputs": [
{
......@@ -475,7 +475,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
......@@ -489,7 +489,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 9,
"metadata": {},
"outputs": [
{
......@@ -527,7 +527,7 @@
")"
]
},
"execution_count": 15,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
......@@ -560,16 +560,16 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"model.compile(method='binary', metrics=[BinaryAccuracy])"
"model.compile(method='binary', metrics=[Accuracy, Precision])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 11,
"metadata": {},
"outputs": [
{
......@@ -591,16 +591,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:02<00:00, 64.79it/s, loss=0.435, metrics={'acc': 0.7901}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.97it/s, loss=0.358, metrics={'acc': 0.799}]\n",
"epoch 2: 100%|██████████| 153/153 [00:02<00:00, 71.36it/s, loss=0.352, metrics={'acc': 0.8352}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.33it/s, loss=0.349, metrics={'acc': 0.8358}]\n",
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 72.24it/s, loss=0.345, metrics={'acc': 0.8383}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.07it/s, loss=0.345, metrics={'acc': 0.8389}]\n",
"epoch 4: 100%|██████████| 153/153 [00:02<00:00, 70.39it/s, loss=0.341, metrics={'acc': 0.8404}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 123.29it/s, loss=0.343, metrics={'acc': 0.8406}]\n",
"epoch 5: 100%|██████████| 153/153 [00:02<00:00, 71.14it/s, loss=0.339, metrics={'acc': 0.8423}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.342, metrics={'acc': 0.8426}]\n"
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 102.41it/s, loss=0.585, metrics={'acc': 0.7512, 'prec': 0.1818}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 98.78it/s, loss=0.513, metrics={'acc': 0.754, 'prec': 0.2429}] \n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 117.30it/s, loss=0.481, metrics={'acc': 0.782, 'prec': 0.8287}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 106.49it/s, loss=0.454, metrics={'acc': 0.7866, 'prec': 0.8245}]\n",
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 124.78it/s, loss=0.44, metrics={'acc': 0.8055, 'prec': 0.781}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 115.36it/s, loss=0.425, metrics={'acc': 0.8077, 'prec': 0.7818}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 125.01it/s, loss=0.418, metrics={'acc': 0.814, 'prec': 0.7661}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 114.92it/s, loss=0.408, metrics={'acc': 0.8149, 'prec': 0.7671}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 116.57it/s, loss=0.404, metrics={'acc': 0.819, 'prec': 0.7527}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 108.89it/s, loss=0.397, metrics={'acc': 0.8203, 'prec': 0.7547}]\n"
]
}
],
......
......@@ -43,7 +43,7 @@
"\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
"from pytorch_widedeep.metrics import BinaryAccuracy"
"from pytorch_widedeep.metrics import Accuracy"
]
},
{
......@@ -273,7 +273,7 @@
"metadata": {},
"outputs": [],
"source": [
"model.compile(method='binary', metrics=[BinaryAccuracy])"
"model.compile(method='binary', metrics=[Accuracy])"
]
},
{
......@@ -307,11 +307,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 118.14it/s, loss=0.475, metrics={'acc': 0.7854}]\n",
"epoch 2: 100%|██████████| 153/153 [00:00<00:00, 154.41it/s, loss=0.373, metrics={'acc': 0.8069}]\n",
"epoch 3: 100%|██████████| 153/153 [00:00<00:00, 153.93it/s, loss=0.365, metrics={'acc': 0.8151}]\n",
"epoch 4: 100%|██████████| 153/153 [00:00<00:00, 154.42it/s, loss=0.362, metrics={'acc': 0.8193}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 147.62it/s, loss=0.36, metrics={'acc': 0.8219}]\n",
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 127.54it/s, loss=0.476, metrics={'acc': 0.7808972948071559}]\n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 126.88it/s, loss=0.373, metrics={'acc': 0.8048268625393494}]\n",
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 141.92it/s, loss=0.365, metrics={'acc': 0.8136820822562895}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 151.56it/s, loss=0.362, metrics={'acc': 0.8182312594374632}]\n",
"epoch 5: 100%|██████████| 153/153 [00:00<00:00, 158.22it/s, loss=0.36, metrics={'acc': 0.8210477823561027}]\n",
" 0%| | 0/153 [00:00<?, ?it/s]"
]
},
......@@ -326,11 +326,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:02<00:00, 75.79it/s, loss=0.392, metrics={'acc': 0.8209}]\n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 76.97it/s, loss=0.35, metrics={'acc': 0.823}] \n",
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 75.86it/s, loss=0.344, metrics={'acc': 0.8251}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 76.79it/s, loss=0.34, metrics={'acc': 0.8269}] \n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 77.39it/s, loss=0.335, metrics={'acc': 0.8286}]\n",
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 78.65it/s, loss=0.397, metrics={'acc': 0.8198073691125158}]\n",
"epoch 2: 100%|██████████| 153/153 [00:02<00:00, 75.69it/s, loss=0.348, metrics={'acc': 0.8221936229255862}]\n",
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 74.79it/s, loss=0.343, metrics={'acc': 0.8243576126737133}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 76.79it/s, loss=0.338, metrics={'acc': 0.8264502057402526}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 79.57it/s, loss=0.334, metrics={'acc': 0.8283059913495252}]\n",
" 0%| | 0/153 [00:00<?, ?it/s]"
]
},
......@@ -345,16 +345,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 114.70it/s, loss=0.36, metrics={'acc': 0.8323}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 120.50it/s, loss=0.364, metrics={'acc': 0.8325}]\n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 112.75it/s, loss=0.359, metrics={'acc': 0.8324}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 119.57it/s, loss=0.364, metrics={'acc': 0.8326}]\n",
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 114.84it/s, loss=0.359, metrics={'acc': 0.8323}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.03it/s, loss=0.363, metrics={'acc': 0.8326}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 114.46it/s, loss=0.359, metrics={'acc': 0.8324}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 122.44it/s, loss=0.363, metrics={'acc': 0.8327}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 114.27it/s, loss=0.358, metrics={'acc': 0.833}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 119.25it/s, loss=0.363, metrics={'acc': 0.8332}]\n"
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 114.10it/s, loss=0.36, metrics={'acc': 0.8323}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 123.16it/s, loss=0.364, metrics={'acc': 0.8325}]\n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 113.50it/s, loss=0.359, metrics={'acc': 0.8325}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 122.56it/s, loss=0.364, metrics={'acc': 0.8327}]\n",
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 110.90it/s, loss=0.359, metrics={'acc': 0.8325}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 119.56it/s, loss=0.363, metrics={'acc': 0.8327}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 112.92it/s, loss=0.359, metrics={'acc': 0.8326}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.00it/s, loss=0.363, metrics={'acc': 0.8329}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 114.15it/s, loss=0.358, metrics={'acc': 0.8327}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 108.91it/s, loss=0.363, metrics={'acc': 0.8329}]\n"
]
}
],
......@@ -450,7 +450,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 8%|▊ | 78/1001 [00:00<00:02, 387.42it/s]"
" 8%|▊ | 84/1001 [00:00<00:02, 416.73it/s]"
]
},
{
......@@ -464,7 +464,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1001/1001 [00:02<00:00, 392.18it/s]\n"
"100%|██████████| 1001/1001 [00:02<00:00, 400.82it/s]\n"
]
},
{
......@@ -848,7 +848,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [00:00<00:00, 58.09it/s, loss=128]\n",
"epoch 1: 100%|██████████| 25/25 [00:00<00:00, 58.03it/s, loss=127]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
......@@ -863,7 +863,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [00:00<00:00, 45.57it/s, loss=119]\n",
"epoch 1: 100%|██████████| 25/25 [00:00<00:00, 47.80it/s, loss=116]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
......@@ -878,7 +878,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [00:04<00:00, 6.12it/s, loss=132]\n",
"epoch 1: 100%|██████████| 25/25 [00:04<00:00, 5.94it/s, loss=132]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
......@@ -893,7 +893,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [01:06<00:00, 2.65s/it, loss=119]\n",
"epoch 1: 100%|██████████| 25/25 [01:12<00:00, 2.92s/it, loss=119]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
......@@ -908,7 +908,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [01:33<00:00, 3.72s/it, loss=108]\n",
"epoch 1: 100%|██████████| 25/25 [01:48<00:00, 4.34s/it, loss=108]\n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
......@@ -923,7 +923,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [01:56<00:00, 4.65s/it, loss=106]\n",
"epoch 1: 100%|██████████| 25/25 [02:05<00:00, 5.01s/it, loss=106] \n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
......@@ -938,7 +938,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [02:23<00:00, 5.75s/it, loss=105] \n",
"epoch 1: 100%|██████████| 25/25 [02:57<00:00, 7.11s/it, loss=105] \n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
......@@ -953,7 +953,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [02:53<00:00, 6.94s/it, loss=105] \n",
"epoch 1: 100%|██████████| 25/25 [03:40<00:00, 8.83s/it, loss=104] \n",
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
......@@ -968,8 +968,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [01:13<00:00, 2.92s/it, loss=120]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.14s/it, loss=109] \n"
"epoch 1: 100%|██████████| 25/25 [01:20<00:00, 3.23s/it, loss=120]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.06s/it, loss=109] \n"
]
}
],
......
......@@ -4,7 +4,7 @@ import pandas as pd
from pytorch_widedeep.optim import RAdam
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.metrics import BinaryAccuracy
from pytorch_widedeep.metrics import Accuracy, Precision
from pytorch_widedeep.callbacks import (
LRHistory,
EarlyStopping,
......@@ -76,7 +76,7 @@ if __name__ == "__main__":
EarlyStopping,
ModelCheckpoint(filepath="model_weights/wd_out"),
]
metrics = [BinaryAccuracy]
metrics = [Accuracy, Precision]
model.compile(
method="binary",
......
......@@ -8,9 +8,8 @@ from collections import Counter
import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
import gender_guesser.detector as gender
from sklearn.preprocessing import MultiLabelBinarizer
warnings.filterwarnings("ignore")
......
......@@ -3,7 +3,7 @@ import torch
import pandas as pd
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.metrics import CategoricalAccuracy
from pytorch_widedeep.metrics import F1Score, Accuracy
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
use_cuda = torch.cuda.is_available()
......@@ -48,7 +48,7 @@ if __name__ == "__main__":
continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deepdense=deepdense, pred_dim=3)
model.compile(method="multiclass", metrics=[CategoricalAccuracy])
model.compile(method="multiclass", metrics=[Accuracy, F1Score])
model.fit(
X_wide=X_wide,
......
[![Build Status](https://travis-ci.org/jrzaurin/pytorch-widedeep.svg?branch=master)](https://travis-ci.org/jrzaurin/pytorch-widedeep)
[![Documentation Status](https://readthedocs.org/projects/pytorch-widedeep/badge/?version=latest)](https://pytorch-widedeep.readthedocs.io/en/latest/?badge=latest)
# pytorch-widedeep
A flexible package to combine tabular data with text and images using wide and
......@@ -64,7 +68,7 @@ from sklearn.model_selection import train_test_split
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import BinaryAccuracy
from pytorch_widedeep.metrics import Accuracy
# these next 4 lines are not directly related to pytorch-widedeep. I assume
# you have downloaded the dataset and placed it in a dir called data/adult/
......@@ -112,7 +116,7 @@ deepdense = DeepDense(
# build, compile and fit
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary", metrics=[BinaryAccuracy])
model.compile(method="binary", metrics=[Accuracy])
model.fit(
X_wide=X_wide,
X_deep=X_deep,
......
import numpy as np
import torch
from .wdtypes import *
from .callbacks import Callback
......@@ -46,23 +47,17 @@ class MetricCallback(Callback):
self.container.reset()
class CategoricalAccuracy(Metric):
r"""Class to calculate the categorical accuracy for multicategorical problems
class Accuracy(Metric):
r"""Class to calculate the accuracy for both binary and categorical problems
Parameters
----------
top_k: int
Accuracy will be computed using the top k most likely classes
Examples
--------
>>> y_true = torch.from_numpy(np.random.choice(3, 100))
>>> y_pred = torch.from_numpy(np.random.rand(100, 3))
>>> metric = CategoricalAccuracy(top_k=top_k)
>>> acc = metric(y_pred, y_true)
top_k: int, default = 1
Accuracy will be computed using the top k most likely classes in
multiclass problems
"""
def __init__(self, top_k=1):
def __init__(self, top_k: int = 1):
self.top_k = top_k
self.correct_count = 0
self.total_count = 0
......@@ -77,41 +72,179 @@ class CategoricalAccuracy(Metric):
self.total_count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
top_k = y_pred.topk(self.top_k, 1)[1]
true_k = y_true.view(len(y_true), 1).expand_as(top_k) # type: ignore
self.correct_count += top_k.eq(true_k).float().sum().item()
num_classes = y_pred.size(1)
if num_classes == 1:
y_pred = y_pred.round()
y_true = y_true.view(-1, 1)
elif num_classes > 1:
y_pred = y_pred.topk(self.top_k, 1)[1]
y_true = y_true.view(-1, 1).expand_as(y_pred) # type: ignore
self.correct_count += y_pred.eq(y_true).sum().item() # type: ignore
self.total_count += len(y_pred) # type: ignore
accuracy = float(self.correct_count) / float(self.total_count)
return np.round(accuracy, 4)
return accuracy
class BinaryAccuracy(Metric):
"""Class to calculate accuracy for binary classification problems
class Precision(Metric):
r"""Class to calculate the precision for both binary and categorical problems
Examples
--------
>>> y_true = torch.from_numpy(np.random.choice(2, 100)).float()
>>> y_pred = deepcopy(y_true.view(-1, 1)).float()
>>> metric = BinaryAccuracy()
>>> acc = metric(y_pred, y_true)
Parameters
----------
average: bool, default = True
This applies only to multiclass problems. if `True` calculate
precision for each label, and find their unweighted mean.
"""
def __init__(self):
self.correct_count = 0
self.total_count = 0
def __init__(self, average: bool = True):
self.average = average
self.true_positives = 0
self.all_positives = 0
self.eps = 1e-20
self._name = "acc"
self._name = "prec"
def reset(self):
def reset(self) -> None:
"""
resets counters to 0
"""
self.correct_count = 0
self.total_count = 0
self.true_positives = 0
self.all_positives = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
y_pred_round = y_pred.round()
self.correct_count += y_pred_round.eq(y_true.view(-1, 1)).float().sum().item()
self.total_count += len(y_pred) # type: ignore
accuracy = float(self.correct_count) / float(self.total_count)
return np.round(accuracy, 4)
num_class = y_pred.size(1)
if num_class == 1:
y_pred = y_pred.round()
y_true = y_true.view(-1, 1)
elif num_class > 1:
y_true = torch.eye(num_class)[y_true.long()]
y_pred = y_pred.topk(1, 1)[1].view(-1)
y_pred = torch.eye(num_class)[y_pred.long()]
self.true_positives += (y_true * y_pred).sum(dim=0) # type:ignore
self.all_positives += y_pred.sum(dim=0) # type:ignore
precision = self.true_positives / (self.all_positives + self.eps)
if self.average:
return precision.mean().item() # type:ignore
else:
return precision
class Recall(Metric):
    r"""Class to calculate the recall for both binary and categorical problems

    Parameters
    ----------
    average: bool, default = True
        This applies only to multiclass problems. if `True` calculate recall
        for each label, and find their unweighted mean.
    """

    def __init__(self, average: bool = True):
        self.average = average
        # running counts, accumulated batch by batch over an epoch
        self.true_positives = 0
        self.actual_positives = 0
        self.eps = 1e-20  # avoids division by zero when a class never occurs
        self._name = "rec"

    def reset(self) -> None:
        """
        resets counters to 0
        """
        self.true_positives = 0
        self.actual_positives = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        n_classes = y_pred.size(1)
        if n_classes == 1:
            # binary: threshold probabilities at 0.5 and align target shape
            y_pred, y_true = y_pred.round(), y_true.view(-1, 1)
        elif n_classes > 1:
            # multiclass: one-hot encode targets and the top-1 predictions
            y_true = torch.eye(n_classes)[y_true.long()]
            top1 = y_pred.topk(1, 1)[1].view(-1)
            y_pred = torch.eye(n_classes)[top1.long()]
        self.true_positives += (y_true * y_pred).sum(dim=0)  # type: ignore
        self.actual_positives += y_true.sum(dim=0)  # type: ignore
        recall = self.true_positives / (self.actual_positives + self.eps)
        # per-label recall unless the caller asked for the unweighted mean
        return recall.mean().item() if self.average else recall  # type: ignore
class FBetaScore(Metric):
    r"""Class to calculate the fbeta score for both binary and categorical problems

    ``FBeta = ((1 + beta^2) * Precision * Recall) / (beta^2 * Precision + Recall)``

    Parameters
    ----------
    beta: int
        Coefficient to control the balance between precision and recall
    average: bool, default = True
        This applies only to multiclass problems. if `True` calculate fbeta
        for each label, and find their unweighted mean.
    """

    def __init__(self, beta: int, average: bool = True):
        self.average = average
        # averaging (if any) happens here, so the sub-metrics must return
        # their per-label (un-averaged) scores
        self.precision = Precision(average=False)
        self.recall = Recall(average=False)
        self.beta = beta
        self.eps = 1e-20  # guards the fbeta denominator against 0/0
        self._name = "".join(["f", str(beta)])

    def reset(self) -> None:
        """
        resets precision and recall
        """
        self.precision.reset()
        self.recall.reset()

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        prec = self.precision(y_pred, y_true)
        rec = self.recall(y_pred, y_true)
        beta2 = self.beta ** 2
        # BUGFIX: when a label has no true and no predicted positives, prec
        # and rec are both 0 and the bare ratio evaluated to NaN (0/0);
        # adding eps to the denominator yields 0 instead, matching the
        # convention used by Precision and Recall above
        fbeta = ((1 + beta2) * prec * rec) / (beta2 * prec + rec + self.eps)
        if self.average:
            return fbeta.mean().item()
        else:
            return fbeta
class F1Score(Metric):
    r"""Class to calculate the f1 score for both binary and categorical problems

    Parameters
    ----------
    average: bool, default = True
        This applies only to multiclass problems. if `True` calculate f1 for
        each label, and find their unweighted mean.
    """

    def __init__(self, average: bool = True):
        # f1 is the harmonic mean of precision and recall, i.e. fbeta with
        # beta=1, so everything is delegated to FBetaScore
        self.f1 = FBetaScore(beta=1, average=average)
        self._name = self.f1._name

    def reset(self) -> None:
        """
        resets counters to 0
        """
        self.f1.reset()

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        score = self.f1(y_pred, y_true)
        return score
......@@ -266,9 +266,9 @@ class WideDeep(nn.Module):
See the ``Callbacks`` section in this documentation or
:obj:`pytorch_widedeep.callbacks`
metrics: List[Metric], Optional. Default=None
Metrics available are: ``BinaryAccuracy`` and
``CategoricalAccuracy`` See the ``Metrics`` section in this
documentation or :obj:`pytorch_widedeep.metrics`
Metrics available are: ``Accuracy``, ``Precision``, ``Recall``,
``FBetaScore`` and ``F1Score``. See the ``Metrics`` section in
this documentation or :obj:`pytorch_widedeep.metrics`
class_weight: Union[float, List[float], Tuple[float]]. Optional. Default=None
- float indicating the weight of the minority class in binary classification
problems (e.g. 9.)
......@@ -587,17 +587,22 @@ class WideDeep(nn.Module):
with trange(train_steps, disable=self.verbose != 1) as t:
for batch_idx, (data, target) in zip(t, train_loader):
t.set_description("epoch %i" % (epoch + 1))
acc, train_loss = self._training_step(data, target, batch_idx)
if acc is not None:
t.set_postfix(metrics=acc, loss=train_loss)
score, train_loss = self._training_step(data, target, batch_idx)
if score is not None:
t.set_postfix(
metrics={k: np.round(v, 4) for k, v in score.items()},
loss=train_loss,
)
else:
t.set_postfix(loss=np.sqrt(train_loss))
if self.lr_scheduler:
self._lr_scheduler_step(step_location="on_batch_end")
self.callback_container.on_batch_end(batch=batch_idx)
epoch_logs["train_loss"] = train_loss
if acc is not None:
epoch_logs["train_acc"] = acc["acc"]
if score is not None:
for k, v in score.items():
log_k = "_".join(["train", k])
epoch_logs[log_k] = v
# eval step...
if epoch % validation_freq == (validation_freq - 1):
if eval_set is not None:
......@@ -612,14 +617,21 @@ class WideDeep(nn.Module):
with trange(eval_steps, disable=self.verbose != 1) as v:
for i, (data, target) in zip(v, eval_loader):
v.set_description("valid")
acc, val_loss = self._validation_step(data, target, i)
if acc is not None:
v.set_postfix(metrics=acc, loss=val_loss)
score, val_loss = self._validation_step(data, target, i)
if score is not None:
v.set_postfix(
metrics={
k: np.round(v, 4) for k, v in score.items()
},
loss=val_loss,
)
else:
v.set_postfix(loss=np.sqrt(val_loss))
epoch_logs["val_loss"] = val_loss
if acc is not None:
epoch_logs["val_acc"] = acc["acc"]
if score is not None:
for k, v in score.items():
log_k = "_".join(["val", k])
epoch_logs[log_k] = v
if self.lr_scheduler:
self._lr_scheduler_step(step_location="on_epoch_end")
#  log and check if early_stop...
......@@ -986,10 +998,10 @@ class WideDeep(nn.Module):
if self.metric is not None:
if self.method == "binary":
acc = self.metric(torch.sigmoid(y_pred), y)
score = self.metric(torch.sigmoid(y_pred), y)
if self.method == "multiclass":
acc = self.metric(F.softmax(y_pred, dim=1), y)
return acc, avg_loss
score = self.metric(F.softmax(y_pred, dim=1), y)
return score, avg_loss
else:
return None, avg_loss
......@@ -1008,10 +1020,10 @@ class WideDeep(nn.Module):
if self.metric is not None:
if self.method == "binary":
acc = self.metric(torch.sigmoid(y_pred), y)
score = self.metric(torch.sigmoid(y_pred), y)
if self.method == "multiclass":
acc = self.metric(F.softmax(y_pred, dim=1), y)
return acc, avg_loss
score = self.metric(F.softmax(y_pred, dim=1), y)
return score, avg_loss
else:
return None, avg_loss
......
__version__ = "0.4.1"
__version__ = "0.4.2"
......@@ -3,23 +3,113 @@ from copy import deepcopy
import numpy as np
import torch
import pytest
from sklearn.metrics import (
f1_score,
fbeta_score,
recall_score,
accuracy_score,
precision_score,
)
from pytorch_widedeep.metrics import BinaryAccuracy, CategoricalAccuracy
from pytorch_widedeep.metrics import (
Recall,
F1Score,
Accuracy,
Precision,
FBetaScore,
)
y_true = torch.from_numpy(np.random.choice(2, 100)).float()
y_pred = deepcopy(y_true.view(-1, 1)).float()
def f2_score_bin(y_true, y_pred):
    """sklearn's fbeta score with beta fixed to 2, for binary targets."""
    return fbeta_score(y_true=y_true, y_pred=y_pred, beta=2)
def test_binary_accuracy():
metric = BinaryAccuracy()
acc = metric(y_pred, y_true)
assert acc == 1.0
y_true_bin_np = np.array([1, 0, 0, 0, 1, 1, 0])
y_pred_bin_np = np.array([0.6, 0.3, 0.2, 0.8, 0.4, 0.9, 0.6])
y_true_bin_pt = torch.from_numpy(y_true_bin_np)
y_pred_bin_pt = torch.from_numpy(y_pred_bin_np).view(-1, 1)
###############################################################################
# Test binary metrics
###############################################################################
@pytest.mark.parametrize(
    "sklearn_metric, widedeep_metric",
    [
        (accuracy_score, Accuracy()),
        (precision_score, Precision()),
        (recall_score, Recall()),
        (f1_score, F1Score()),
        (f2_score_bin, FBetaScore(beta=2)),
    ],
)
def test_binary_metrics(sklearn_metric, widedeep_metric):
    """Each widedeep metric must agree with its sklearn counterpart."""
    # sklearn gets hard class predictions; the widedeep metric receives the
    # raw probabilities and thresholds them internally via round()
    expected = sklearn_metric(y_true_bin_np, y_pred_bin_np.round())
    actual = widedeep_metric(y_pred_bin_pt, y_true_bin_pt)
    assert np.isclose(expected, actual)
###############################################################################
# Test top_k for Accuracy
###############################################################################
@pytest.mark.parametrize("top_k, expected_acc", [(1, 0.33), (2, 0.66)])
def test_categorical_accuracy(top_k, expected_acc):
def test_categorical_accuracy_topk(top_k, expected_acc):
y_true = torch.from_numpy(np.random.choice(3, 100))
y_pred = torch.from_numpy(np.random.rand(100, 3))
metric = CategoricalAccuracy(top_k=top_k)
metric = Accuracy(top_k=top_k)
acc = metric(y_pred, y_true)
assert np.isclose(acc, expected_acc, atol=0.3)
###############################################################################
# Test multiclass metrics
###############################################################################
y_true_multi_np = np.array([1, 0, 2, 1, 1, 2, 2, 0, 0, 0])
y_pred_muli_np = np.array(
[
[0.2, 0.6, 0.2],
[0.4, 0.5, 0.1],
[0.1, 0.1, 0.8],
[0.1, 0.6, 0.3],
[0.1, 0.8, 0.1],
[0.1, 0.6, 0.6],
[0.2, 0.6, 0.8],
[0.6, 0.1, 0.3],
[0.7, 0.2, 0.1],
[0.1, 0.7, 0.2],
]
)
y_true_multi_pt = torch.from_numpy(y_true_multi_np)
y_pred_multi_pt = torch.from_numpy(y_pred_muli_np)
def f2_score_multi(y_true, y_pred, average):
    """sklearn's fbeta score with beta fixed to 2, forwarding ``average``."""
    return fbeta_score(y_true=y_true, y_pred=y_pred, average=average, beta=2)
@pytest.mark.parametrize(
    "sklearn_metric, widedeep_metric",
    [
        (accuracy_score, Accuracy()),
        (precision_score, Precision()),
        (recall_score, Recall()),
        (f1_score, F1Score()),
        (f2_score_multi, FBetaScore(beta=2)),
    ],
)
def test_muticlass_metrics(sklearn_metric, widedeep_metric):
    """Each widedeep metric must agree with its sklearn counterpart.

    NOTE(review): function name has a typo ("muticlass"); kept as-is since
    pytest collects it by prefix and renaming is out of scope here.
    """
    # accuracy_score takes no ``average`` argument; the other sklearn metrics
    # need macro averaging to match widedeep's unweighted per-label mean
    hard_preds = y_pred_muli_np.argmax(axis=1)
    if sklearn_metric.__name__ == "accuracy_score":
        expected = sklearn_metric(y_true_multi_np, hard_preds)
    else:
        expected = sklearn_metric(y_true_multi_np, hard_preds, average="macro")
    assert np.isclose(expected, widedeep_metric(y_pred_multi_pt, y_true_multi_pt))
......@@ -9,7 +9,7 @@ from sklearn.utils import Bunch
from torch.utils.data import Dataset, DataLoader
from pytorch_widedeep.models import Wide, DeepDense
from pytorch_widedeep.metrics import BinaryAccuracy
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.models._warmup import WarmUp
from pytorch_widedeep.models.deep_image import conv_layer
......@@ -138,7 +138,7 @@ wdset = WDset(X_wide, X_deep, X_text, X_image, target)
wdloader = DataLoader(wdset, batch_size=10, shuffle=True)
# Instantiate the WarmUp class
warmer = WarmUp(loss_fn, BinaryAccuracy(), "binary", False)
warmer = WarmUp(loss_fn, Accuracy(), "binary", False)
# List the layers for the warm_gradual method
text_layers = [c for c in list(deeptext.children())[1:]][::-1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册