Greenplum / Pytorch Widedeep

Commit 65465a45 (unverified)
Authored on Jul 21, 2020 by Javier; committed via GitHub on Jul 21, 2020

Merge pull request #17 from jrzaurin/precision_recall

Precision recall

Parents: 31c2d8ef, 393ea436
Showing 21 changed files with 573 additions and 308 deletions (+573, -308)
README.md                                                    +5    -3
VERSION                                                      +1    -1
code_style.sh                                                +1    -1
docs/examples.rst                                            +13   -0
docs/index.rst                                               +1    -0
docs/quick_start.rst                                         +2    -2
docs/wide_deep/metrics.rst                                   +17   -2
examples/01_Preprocessors_and_utils.ipynb                    +40   -40
examples/02_Model_Components.ipynb                           +9    -9
examples/03_Binary_Classification_with_Defaults.ipynb        +26   -26
examples/04_Binary_Classification_Varying_Parameters.ipynb   +111  -115
examples/06_WarmUp_Model_Components.ipynb                    +34   -34
examples/adult_script.py                                     +2    -2
examples/airbnb_data_preprocessing.py                        +1    -2
examples/airbnb_script_multiclass.py                         +2    -2
pypi_README.md                                               +6    -2
pytorch_widedeep/metrics.py                                  +169  -36
pytorch_widedeep/models/wide_deep.py                         +31   -19
pytorch_widedeep/version.py                                  +1    -1
tests/test_model_functioning/test_metrics.py                 +99   -9
tests/test_warm_up/test_warm_up_routines.py                  +2    -2
README.md (+5, -3)

@@ -40,7 +40,9 @@ final output neuron or neurons, depending on whether we are performing a
 binary classification or regression, or a multi-class classification. The
 components within the faded-pink rectangles are concatenated.

-In math terms, and following the notation in the
-[paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated as:
+In math terms, and following the notation in the
+[paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated
+as:

 <p align="center">
   <img width="500" src="docs/figures/architecture_1_math.png">
 </p>

@@ -130,7 +132,7 @@ from sklearn.model_selection import train_test_split
 from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
 from pytorch_widedeep.models import Wide, DeepDense, WideDeep
-from pytorch_widedeep.metrics import BinaryAccuracy
+from pytorch_widedeep.metrics import Accuracy

 # these next 4 lines are not directly related to pytorch-widedeep. I assume
 # you have downloaded the dataset and place it in a dir called data/adult/

@@ -178,7 +180,7 @@ deepdense = DeepDense(
 # build, compile and fit
 model = WideDeep(wide=wide, deepdense=deepdense)
-model.compile(method="binary", metrics=[BinaryAccuracy])
+model.compile(method="binary", metrics=[Accuracy])
 model.fit(
     X_wide=X_wide,
     X_deep=X_deep,
VERSION (+1, -1)

-0.4.1
\ No newline at end of file
+0.4.2
\ No newline at end of file
code_style.sh (+1, -1)

 # sort imports
-isort --recursive . pytorch_widedeep tests examples setup.py
+isort . pytorch_widedeep tests examples setup.py
 # Black code style
 black . pytorch_widedeep tests examples setup.py
 # flake8 standards
docs/examples.rst (new file, +13, -0)

pytorch-widedeep Examples
*****************************

This section provides links to example notebooks that may be helpful to better
understand the functionalities within ``pytorch-widedeep`` and how to use
them to address different problems.

* `Preprocessors and Utils <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/01_Preprocessors_and_utils.ipynb>`__
* `Model Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/02_Model_Components.ipynb>`__
* `Binary Classification with default parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/03_Binary_Classification_with_Defaults.ipynb>`__
* `Binary Classification with varying parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/04_Binary_Classification_Varying_Parameters.ipynb>`__
* `Regression with Images and Text <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/05_Regression_with_Images_and_Text.ipynb>`__
* `Warm up routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_WarmUp_Model_Components.ipynb>`__
docs/index.rst (+1, -0)

@@ -19,6 +19,7 @@ Documentation
    Preprocessing <preprocessing>
    Model Components <model_components>
    Wide and Deep Models <wide_deep/index>
+   Examples <examples>

 Introduction
docs/quick_start.rst (+2, -2)

@@ -30,7 +30,7 @@ Prepare the wide and deep columns
 from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
 from pytorch_widedeep.models import Wide, DeepDense, WideDeep
-from pytorch_widedeep.metrics import BinaryAccuracy
+from pytorch_widedeep.metrics import Accuracy

 # prepare wide, crossed, embedding and continuous columns
 wide_cols = [

@@ -83,7 +83,7 @@ Build, compile, fit and predict
 # build, compile and fit
 model = WideDeep(wide=wide, deepdense=deepdense)
-model.compile(method="binary", metrics=[BinaryAccuracy])
+model.compile(method="binary", metrics=[Accuracy])
 model.fit(
     X_wide=X_wide,
     X_deep=X_deep,
docs/wide_deep/metrics.rst (+17, -2)

 Metrics
 =======

-.. autoclass:: pytorch_widedeep.metrics.BinaryAccuracy
+.. autoclass:: pytorch_widedeep.metrics.Accuracy
    :members:
    :undoc-members:
    :show-inheritance:

-.. autoclass:: pytorch_widedeep.metrics.CategoricalAccuracy
+.. autoclass:: pytorch_widedeep.metrics.Precision
    :members:
    :undoc-members:
    :show-inheritance:

+.. autoclass:: pytorch_widedeep.metrics.Recall
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. autoclass:: pytorch_widedeep.metrics.FBetaScore
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. autoclass:: pytorch_widedeep.metrics.F1Score
+   :members:
+   :undoc-members:
+   :show-inheritance:
examples/01_Preprocessors_and_utils.ipynb (+40, -40)

(diff collapsed in the original view)
examples/02_Model_Components.ipynb (+9, -9)

The visible changes are re-executed cell outputs: the printed tensors at
@@ -170,11 +170,11 @@ (a dropout-masked output, grad_fn=<MulBackward0>) and at
@@ -484,10 +484,10 @@ (a LeakyReLU output, grad_fn=<LeakyReluBackward1>) show
different random values between runs.
examples/03_Binary_Classification_with_Defaults.ipynb (+26, -26)

-"from pytorch_widedeep.metrics import BinaryAccuracy"
+"from pytorch_widedeep.metrics import Accuracy, Precision"

-"model.compile(method='binary', metrics=[BinaryAccuracy])"
+"model.compile(method='binary', metrics=[Accuracy, Precision])"

Execution counts are renumbered (5-17 in the old run, 1-11 in the new one) and
the training/validation progress bars are regenerated; the logged metrics dict
now contains 'prec' alongside 'acc', e.g. for epoch 5:

-epoch 5: 100%| 153/153 [00:02<00:00, 71.14it/s, loss=0.339, metrics={'acc': 0.8423}]
+epoch 5: 100%| 153/153 [00:01<00:00, 116.57it/s, loss=0.404, metrics={'acc': 0.819, 'prec': 0.7527}]
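The `{'acc': …, 'prec': …}` dicts shown in the progress bars are running values, updated batch by batch from cumulative counters rather than recomputed per batch. A minimal pure-Python sketch of that accumulation (class and variable names are mine, not the library's):

```python
class RunningAccuracy:
    """Accumulate correct/total counts across batches; report the running ratio."""

    def __init__(self):
        self.correct, self.total = 0, 0

    def update(self, preds, targets):  # preds already rounded to 0/1
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(preds)
        return self.correct / self.total  # running value shown per batch


metric = RunningAccuracy()
first = metric.update([1, 0, 1, 1], [1, 0, 0, 1])  # batch 1: 3 of 4 correct
second = metric.update([0, 0], [0, 1])             # batch 2: 4 of 6 overall
```

Note that `second` reflects all six samples seen so far, which is why the displayed metric drifts smoothly over an epoch instead of jumping with each batch.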
examples/04_Binary_Classification_Varying_Parameters.ipynb (+111, -115)

(diff collapsed in the original view)
examples/06_WarmUp_Model_Components.ipynb (+34, -34)

-"from pytorch_widedeep.metrics import BinaryAccuracy"
+"from pytorch_widedeep.metrics import Accuracy"

-"model.compile(method='binary', metrics=[BinaryAccuracy])"
+"model.compile(method='binary', metrics=[Accuracy])"

The remaining changes are regenerated progress-bar outputs: re-run timings
(it/s, s/it) and losses differ slightly, and in the warm-up cells the accuracy
is now logged at full float precision, e.g.

-epoch 1: 100%| 153/153 [00:01<00:00, 118.14it/s, loss=0.475, metrics={'acc': 0.7854}]
+epoch 1: 100%| 153/153 [00:01<00:00, 127.54it/s, loss=0.476, metrics={'acc': 0.7808972948071559}]
examples/adult_script.py (+2, -2)

@@ -4,7 +4,7 @@ import pandas as pd
 from pytorch_widedeep.optim import RAdam
 from pytorch_widedeep.models import Wide, WideDeep, DeepDense
-from pytorch_widedeep.metrics import BinaryAccuracy
+from pytorch_widedeep.metrics import Accuracy, Precision
 from pytorch_widedeep.callbacks import (
     LRHistory,
     EarlyStopping,

@@ -76,7 +76,7 @@ if __name__ == "__main__":
         EarlyStopping,
         ModelCheckpoint(filepath="model_weights/wd_out"),
     ]
-    metrics = [BinaryAccuracy]
+    metrics = [Accuracy, Precision]
     model.compile(
         method="binary",
examples/airbnb_data_preprocessing.py (+1, -2)

@@ -8,9 +8,8 @@ from collections import Counter
 import numpy as np
 import pandas as pd
-from sklearn.preprocessing import MultiLabelBinarizer
 import gender_guesser.detector as gender
+from sklearn.preprocessing import MultiLabelBinarizer

 warnings.filterwarnings("ignore")
examples/airbnb_script_multiclass.py (+2, -2)

@@ -3,7 +3,7 @@ import torch
 import pandas as pd
 from pytorch_widedeep.models import Wide, WideDeep, DeepDense
-from pytorch_widedeep.metrics import CategoricalAccuracy
+from pytorch_widedeep.metrics import F1Score, Accuracy
 from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor

 use_cuda = torch.cuda.is_available()

@@ -48,7 +48,7 @@ if __name__ == "__main__":
         continuous_cols=continuous_cols,
     )
     model = WideDeep(wide=wide, deepdense=deepdense, pred_dim=3)
-    model.compile(method="multiclass", metrics=[CategoricalAccuracy])
+    model.compile(method="multiclass", metrics=[Accuracy, F1Score])
     model.fit(
         X_wide=X_wide,
pypi_README.md (+6, -2)

+[![Build Status](https://travis-ci.org/jrzaurin/pytorch-widedeep.svg?branch=master)](https://travis-ci.org/jrzaurin/pytorch-widedeep)
+[![Documentation Status](https://readthedocs.org/projects/pytorch-widedeep/badge/?version=latest)](https://pytorch-widedeep.readthedocs.io/en/latest/?badge=latest)
+
 # pytorch-widedeep

 A flexible package to combine tabular data with text and images using wide and

@@ -64,7 +68,7 @@ from sklearn.model_selection import train_test_split
 from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
 from pytorch_widedeep.models import Wide, DeepDense, WideDeep
-from pytorch_widedeep.metrics import BinaryAccuracy
+from pytorch_widedeep.metrics import Accuracy

 # these next 4 lines are not directly related to pytorch-widedeep. I assume
 # you have downloaded the dataset and place it in a dir called data/adult/

@@ -112,7 +116,7 @@ deepdense = DeepDense(
 # build, compile and fit
 model = WideDeep(wide=wide, deepdense=deepdense)
-model.compile(method="binary", metrics=[BinaryAccuracy])
+model.compile(method="binary", metrics=[Accuracy])
 model.fit(
     X_wide=X_wide,
     X_deep=X_deep,
pytorch_widedeep/metrics.py (+169, -36)

 import numpy as np
+import torch

 from .wdtypes import *
 from .callbacks import Callback

@@ -46,23 +47,17 @@ class MetricCallback(Callback):
         self.container.reset()

-class CategoricalAccuracy(Metric):
-    r"""Class to calculate the categorical accuracy for multicategorical problems
+class Accuracy(Metric):
+    r"""Class to calculate the accuracy for both binary and categorical problems

     Parameters
     ----------
-    top_k: int
-        Accuracy will be computed using the top k most likely classes
-
-    Examples
-    --------
-    >>> y_true = torch.from_numpy(np.random.choice(3, 100))
-    >>> y_pred = torch.from_numpy(np.random.rand(100, 3))
-    >>> metric = CategoricalAccuracy(top_k=top_k)
-    >>> acc = metric(y_pred, y_true)
+    top_k: int, default = 1
+        Accuracy will be computed using the top k most likely classes in
+        multiclass problems
     """

-    def __init__(self, top_k=1):
+    def __init__(self, top_k: int = 1):
         self.top_k = top_k
         self.correct_count = 0
         self.total_count = 0

@@ -77,41 +72,179 @@ class CategoricalAccuracy(Metric):
         self.total_count = 0

     def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
-        top_k = y_pred.topk(self.top_k, 1)[1]
-        true_k = y_true.view(len(y_true), 1).expand_as(top_k)  # type: ignore
-        self.correct_count += top_k.eq(true_k).float().sum().item()
+        num_classes = y_pred.size(1)
+        if num_classes == 1:
+            y_pred = y_pred.round()
+            y_true = y_true.view(-1, 1)
+        elif num_classes > 1:
+            y_pred = y_pred.topk(self.top_k, 1)[1]
+            y_true = y_true.view(-1, 1).expand_as(y_pred)  # type: ignore
+        self.correct_count += y_pred.eq(y_true).sum().item()  # type: ignore
         self.total_count += len(y_pred)  # type: ignore
         accuracy = float(self.correct_count) / float(self.total_count)
-        return np.round(accuracy, 4)
+        return accuracy
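The new `Accuracy.__call__` dispatches on the width of `y_pred`: one output column means binary (round the probability), several columns mean multiclass (take the top-k most likely classes). The same dispatch can be sketched in plain Python, without torch (function and variable names here are mine, for illustration only):

```python
def accuracy(y_pred, y_true, top_k=1):
    """y_pred: list of rows; a 1-wide row is a binary probability,
    a wider row holds per-class scores."""
    correct = 0
    for row, truth in zip(y_pred, y_true):
        if len(row) == 1:                      # binary: round the probability
            correct += int(round(row[0]) == truth)
        else:                                  # multiclass: top-k most likely classes
            ranked = sorted(range(len(row)), key=lambda i: -row[i])
            correct += int(truth in ranked[:top_k])
    return correct / len(y_pred)


# binary: probs 0.9, 0.2, 0.7 vs labels 1, 0, 0 -> 2 of 3 correct
binary_acc = accuracy([[0.9], [0.2], [0.7]], [1, 0, 0])
# multiclass: argmax of each row vs labels 2, 0 -> both correct
multi_acc = accuracy([[0.1, 0.2, 0.7], [0.8, 0.1, 0.1]], [2, 0])
```

This shape-based dispatch is what lets one `Accuracy` class replace both the removed `BinaryAccuracy` and `CategoricalAccuracy`.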
-class BinaryAccuracy(Metric):
-    """Class to calculate accuracy for binary classification problems
-
-    Examples
-    --------
-    >>> y_true = torch.from_numpy(np.random.choice(2, 100)).float()
-    >>> y_pred = deepcopy(y_true.view(-1, 1)).float()
-    >>> metric = BinaryAccuracy()
-    >>> acc = metric(y_pred, y_true)
-    """
-
-    def __init__(self):
-        self.correct_count = 0
-        self.total_count = 0
-        self._name = "acc"
-
-    def reset(self):
-        """
-        resets counters to 0
-        """
-        self.correct_count = 0
-        self.total_count = 0
-
-    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
-        y_pred_round = y_pred.round()
-        self.correct_count += y_pred_round.eq(y_true.view(-1, 1)).float().sum().item()
-        self.total_count += len(y_pred)  # type: ignore
-        accuracy = float(self.correct_count) / float(self.total_count)
-        return np.round(accuracy, 4)
+class Precision(Metric):
+    r"""Class to calculate the precision for both binary and categorical
+    problems
+
+    Parameters
+    ----------
+    average: bool, default = True
+        This applies only to multiclass problems. if `True` calculate
+        precision for each label, and find their unweighted mean.
+    """
+
+    def __init__(self, average: bool = True):
+        self.average = average
+        self.true_positives = 0
+        self.all_positives = 0
+        self.eps = 1e-20
+        self._name = "prec"
+
+    def reset(self) -> None:
+        """
+        resets counters to 0
+        """
+        self.true_positives = 0
+        self.all_positives = 0
+
+    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+        num_class = y_pred.size(1)
+        if num_class == 1:
+            y_pred = y_pred.round()
+            y_true = y_true.view(-1, 1)
+        elif num_class > 1:
+            y_true = torch.eye(num_class)[y_true.long()]
+            y_pred = y_pred.topk(1, 1)[1].view(-1)
+            y_pred = torch.eye(num_class)[y_pred.long()]
+        self.true_positives += (y_true * y_pred).sum(dim=0)  # type:ignore
+        self.all_positives += y_pred.sum(dim=0)  # type:ignore
+        precision = self.true_positives / (self.all_positives + self.eps)
+        if self.average:
+            return precision.mean().item()  # type:ignore
+        else:
+            return precision
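In the multiclass branch, `Precision` one-hot encodes both the targets and the argmax predictions with `torch.eye`, then accumulates per-class true positives and predicted positives as vector sums. The same bookkeeping can be written without torch; this is a plain-Python sketch with hypothetical names, using explicit counters instead of one-hot matrices:

```python
def per_class_precision(y_pred, y_true, num_class, eps=1e-20):
    """Per-class precision: of the rows predicted as class c, how many were right."""
    true_positives = [0] * num_class   # times class c was predicted and correct
    all_positives = [0] * num_class    # times class c was predicted at all
    for row, truth in zip(y_pred, y_true):
        pred = max(range(num_class), key=lambda i: row[i])  # argmax of the row
        all_positives[pred] += 1
        if pred == truth:
            true_positives[pred] += 1
    # eps guards against division by zero for never-predicted classes
    return [tp / (ap + eps) for tp, ap in zip(true_positives, all_positives)]


scores = per_class_precision(
    [[0.8, 0.2], [0.6, 0.4], [0.3, 0.7]],  # predictions: class 0, 0, 1
    [0, 1, 1],                             # targets
    num_class=2,
)
macro = sum(scores) / len(scores)  # average=True -> unweighted mean
```

Class 0 is predicted twice but correct once (precision 0.5); class 1 is predicted once, correctly (precision ~1.0); their unweighted mean is what `average=True` returns.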
+class Recall(Metric):
+    r"""Class to calculate the recall for both binary and categorical problems
+
+    Parameters
+    ----------
+    average: bool, default = True
+        This applies only to multiclass problems. if `True` calculate recall
+        for each label, and find their unweighted mean.
+    """
+
+    def __init__(self, average: bool = True):
+        self.average = average
+        self.true_positives = 0
+        self.actual_positives = 0
+        self.eps = 1e-20
+        self._name = "rec"
+
+    def reset(self) -> None:
+        """
+        resets counters to 0
+        """
+        self.true_positives = 0
+        self.actual_positives = 0
+
+    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+        num_class = y_pred.size(1)
+        if num_class == 1:
+            y_pred = y_pred.round()
+            y_true = y_true.view(-1, 1)
+        elif num_class > 1:
+            y_true = torch.eye(num_class)[y_true.long()]
+            y_pred = y_pred.topk(1, 1)[1].view(-1)
+            y_pred = torch.eye(num_class)[y_pred.long()]
+        self.true_positives += (y_true * y_pred).sum(dim=0)  # type: ignore
+        self.actual_positives += y_true.sum(dim=0)  # type: ignore
+        recall = self.true_positives / (self.actual_positives + self.eps)
+        if self.average:
+            return recall.mean().item()  # type:ignore
+        else:
+            return recall
+class FBetaScore(Metric):
+    r"""Class to calculate the fbeta score for both binary and categorical problems
+
+    ``FBeta = ((1 + beta^2) * Precision * Recall) / (beta^2 * Precision + Recall)``
+
+    Parameters
+    ----------
+    beta: int
+        Coefficient to control the balance between precision and recall
+    average: bool, default = True
+        This applies only to multiclass problems. if `True` calculate fbeta
+        for each label, and find their unweighted mean.
+    """
+
+    def __init__(self, beta: int, average: bool = True):
+        self.average = average
+        self.precision = Precision(average=False)
+        self.recall = Recall(average=False)
+        self.beta = beta
+        self._name = "".join(["f", str(beta)])
+
+    def reset(self) -> None:
+        """
+        resets precision and recall
+        """
+        self.precision.reset()
+        self.recall.reset()
+
+    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+        prec = self.precision(y_pred, y_true)
+        rec = self.recall(y_pred, y_true)
+        beta2 = self.beta ** 2
+        fbeta = ((1 + beta2) * prec * rec) / (beta2 * prec + rec)
+        if self.average:
+            return fbeta.mean().item()
+        else:
+            return fbeta
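`FBetaScore` combines its internal (non-averaged) `Precision` and `Recall` through the formula in its docstring; with beta = 1 this reduces to the harmonic mean of the two, which is why `F1Score` can delegate to `FBetaScore(beta=1)`. A quick arithmetic check of the formula on assumed precision/recall values (not taken from the diff):

```python
def fbeta(precision, recall, beta):
    # FBeta = ((1 + beta^2) * P * R) / (beta^2 * P + R)
    beta2 = beta ** 2
    return ((1 + beta2) * precision * recall) / (beta2 * precision + recall)


f1 = fbeta(0.75, 0.6, beta=1)  # beta=1: harmonic mean of P and R
f2 = fbeta(0.75, 0.6, beta=2)  # beta=2: weights recall more heavily
harmonic = 2 * 0.75 * 0.6 / (0.75 + 0.6)  # matches f1
```

Because `f2 < f1` here (recall is the weaker number and beta = 2 emphasizes it), the beta knob behaves as the docstring describes: larger beta penalizes low recall more than low precision.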
+class F1Score(Metric):
+    r"""Class to calculate the f1 score for both binary and categorical problems
+
+    Parameters
+    ----------
+    average: bool, default = True
+        This applies only to multiclass problems. if `True` calculate f1 for
+        each label, and find their unweighted mean.
+    """
+
+    def __init__(self, average: bool = True):
+        self.f1 = FBetaScore(beta=1, average=average)
+        self._name = self.f1._name
+
+    def reset(self) -> None:
+        """
+        resets counters to 0
+        """
+        self.f1.reset()
+
+    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+        return self.f1(y_pred, y_true)
pytorch_widedeep/models/wide_deep.py
浏览文件 @
65465a45
...
...
@@ -266,9 +266,9 @@ class WideDeep(nn.Module):
See the ``Callbacks`` section in this documentation or
:obj:`pytorch_widedeep.callbacks`
metrics: List[Metric], Optional. Default=None
Metrics available are: ``
BinaryAccuracy`` and
``
CategoricalAccuracy`` See the ``Metrics`` section in this
documentation or :obj:`pytorch_widedeep.metrics`
Metrics available are: ``
Accuracy``, ``Precision``, ``Recall``,
``
FBetaScore`` and ``F1Score``. See the ``Metrics`` section in
this
documentation or :obj:`pytorch_widedeep.metrics`
class_weight: Union[float, List[float], Tuple[float]]. Optional. Default=None
- float indicating the weight of the minority class in binary classification
problems (e.g. 9.)
...
...
@@ -587,17 +587,22 @@ class WideDeep(nn.Module):
             with trange(train_steps, disable=self.verbose != 1) as t:
                 for batch_idx, (data, target) in zip(t, train_loader):
                     t.set_description("epoch %i" % (epoch + 1))
-                    acc, train_loss = self._training_step(data, target, batch_idx)
-                    if acc is not None:
-                        t.set_postfix(metrics=acc, loss=train_loss)
+                    score, train_loss = self._training_step(data, target, batch_idx)
+                    if score is not None:
+                        t.set_postfix(
+                            metrics={k: np.round(v, 4) for k, v in score.items()},
+                            loss=train_loss,
+                        )
                     else:
                         t.set_postfix(loss=np.sqrt(train_loss))
                     if self.lr_scheduler:
                         self._lr_scheduler_step(step_location="on_batch_end")
                     self.callback_container.on_batch_end(batch=batch_idx)
             epoch_logs["train_loss"] = train_loss
-            if acc is not None:
-                epoch_logs["train_acc"] = acc["acc"]
+            if score is not None:
+                for k, v in score.items():
+                    log_k = "_".join(["train", k])
+                    epoch_logs[log_k] = v

             # eval step...
             if epoch % validation_freq == (validation_freq - 1):
                 if eval_set is not None:
...
...
@@ -612,14 +617,21 @@ class WideDeep(nn.Module):
                     with trange(eval_steps, disable=self.verbose != 1) as v:
                         for i, (data, target) in zip(v, eval_loader):
                             v.set_description("valid")
-                            acc, val_loss = self._validation_step(data, target, i)
-                            if acc is not None:
-                                v.set_postfix(metrics=acc, loss=val_loss)
+                            score, val_loss = self._validation_step(data, target, i)
+                            if score is not None:
+                                v.set_postfix(
+                                    metrics={
+                                        k: np.round(v, 4) for k, v in score.items()
+                                    },
+                                    loss=val_loss,
+                                )
                             else:
                                 v.set_postfix(loss=np.sqrt(val_loss))
                     epoch_logs["val_loss"] = val_loss
-                    if acc is not None:
-                        epoch_logs["val_acc"] = acc["acc"]
+                    if score is not None:
+                        for k, v in score.items():
+                            log_k = "_".join(["val", k])
+                            epoch_logs[log_k] = v

             if self.lr_scheduler:
                 self._lr_scheduler_step(step_location="on_epoch_end")

             # log and check if early_stop...
...
...
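The logging change in the hunks above replaces the single hard-coded `train_acc` / `val_acc` entry with one log key per metric name, so several metrics can be tracked at once. A small pure-Python sketch of that key construction (the metric names and values here are illustrative, not taken from a real run):

```python
import numpy as np

# hypothetical score dict, as a multi-metric Metric callable might return it
score = {"acc": 0.91234567, "prec": 0.8751234}

# progress-bar display: round each value to 4 decimals
display = {k: np.round(v, 4) for k, v in score.items()}

# epoch logs: prefix every metric key with the split name
epoch_logs = {"_".join(["train", k]): v for k, v in score.items()}
# sorted(epoch_logs) -> ['train_acc', 'train_prec']
```

The same comprehension with `"val"` as the prefix produces the validation-side keys.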
@@ -986,10 +998,10 @@ class WideDeep(nn.Module):
         if self.metric is not None:
             if self.method == "binary":
-                acc = self.metric(torch.sigmoid(y_pred), y)
+                score = self.metric(torch.sigmoid(y_pred), y)
             if self.method == "multiclass":
-                acc = self.metric(F.softmax(y_pred, dim=1), y)
-            return acc, avg_loss
+                score = self.metric(F.softmax(y_pred, dim=1), y)
+            return score, avg_loss
         else:
             return None, avg_loss
...
...
@@ -1008,10 +1020,10 @@ class WideDeep(nn.Module):
         if self.metric is not None:
             if self.method == "binary":
-                acc = self.metric(torch.sigmoid(y_pred), y)
+                score = self.metric(torch.sigmoid(y_pred), y)
             if self.method == "multiclass":
-                acc = self.metric(F.softmax(y_pred, dim=1), y)
-            return acc, avg_loss
+                score = self.metric(F.softmax(y_pred, dim=1), y)
+            return score, avg_loss
         else:
             return None, avg_loss
...
...
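Both `_training_step` and `_validation_step` apply an activation before scoring: a sigmoid on the raw logit in binary mode, a row-wise softmax over class logits in multiclass mode. A numpy sketch of those two activations (standalone; the library itself uses `torch.sigmoid` and `F.softmax`):

```python
import numpy as np

def sigmoid(z):
    # maps a raw logit to a probability in (0, 1); sigmoid(0) == 0.5
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z, axis=1):
    # numerically stable row-wise softmax: subtract the row max first,
    # then normalize so each row sums to 1
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

binary_logits = np.array([-2.0, 0.0, 3.0])
multi_logits = np.array([[1.0, 2.0, 0.5], [0.1, 0.1, 3.0]])

probs_bin = sigmoid(binary_logits)    # monotone in the logit
probs_multi = softmax(multi_logits)   # each row sums to 1
```

Passing probabilities rather than logits is what lets the metric classes threshold at 0.5 (binary) or take an argmax (multiclass).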
pytorch_widedeep/version.py
-__version__ = "0.4.1"
+__version__ = "0.4.2"
tests/test_model_functioning/test_metrics.py
...
...
@@ -3,23 +3,113 @@ from copy import deepcopy
 import numpy as np
 import torch
 import pytest
+from sklearn.metrics import (
+    f1_score,
+    fbeta_score,
+    recall_score,
+    accuracy_score,
+    precision_score,
+)

-from pytorch_widedeep.metrics import BinaryAccuracy, CategoricalAccuracy
+from pytorch_widedeep.metrics import (
+    Recall,
+    F1Score,
+    Accuracy,
+    Precision,
+    FBetaScore,
+)

-y_true = torch.from_numpy(np.random.choice(2, 100)).float()
-y_pred = deepcopy(y_true.view(-1, 1)).float()

+def f2_score_bin(y_true, y_pred):
+    return fbeta_score(y_true, y_pred, beta=2)

-def test_binary_accuracy():
-    metric = BinaryAccuracy()
-    acc = metric(y_pred, y_true)
-    assert acc == 1.0

+y_true_bin_np = np.array([1, 0, 0, 0, 1, 1, 0])
+y_pred_bin_np = np.array([0.6, 0.3, 0.2, 0.8, 0.4, 0.9, 0.6])
+y_true_bin_pt = torch.from_numpy(y_true_bin_np)
+y_pred_bin_pt = torch.from_numpy(y_pred_bin_np).view(-1, 1)

+###############################################################################
+# Test binary metrics
+###############################################################################
+@pytest.mark.parametrize(
+    "sklearn_metric, widedeep_metric",
+    [
+        (accuracy_score, Accuracy()),
+        (precision_score, Precision()),
+        (recall_score, Recall()),
+        (f1_score, F1Score()),
+        (f2_score_bin, FBetaScore(beta=2)),
+    ],
+)
+def test_binary_metrics(sklearn_metric, widedeep_metric):
+    assert np.isclose(
+        sklearn_metric(y_true_bin_np, y_pred_bin_np.round()),
+        widedeep_metric(y_pred_bin_pt, y_true_bin_pt),
+    )
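Working the binary fixtures through by hand shows what the parametrized test checks. Rounding `y_pred_bin_np` at 0.5 gives `[1, 0, 0, 1, 0, 1, 1]`, so there are 2 true positives, 2 false positives and 1 false negative. A numpy sketch of the resulting scores:

```python
import numpy as np

y_true = np.array([1, 0, 0, 0, 1, 1, 0])
y_prob = np.array([0.6, 0.3, 0.2, 0.8, 0.4, 0.9, 0.6])
y_hat = (y_prob >= 0.5).astype(int)  # [1, 0, 0, 1, 0, 1, 1]

tp = int(((y_hat == 1) & (y_true == 1)).sum())  # 2
fp = int(((y_hat == 1) & (y_true == 0)).sum())  # 2
fn = int(((y_hat == 0) & (y_true == 1)).sum())  # 1

precision = tp / (tp + fp)                              # 0.5
recall = tp / (tp + fn)                                 # 2/3
f1 = 2 * precision * recall / (precision + recall)      # 4/7 ~= 0.5714
f2 = 5 * precision * recall / (4 * precision + recall)  # 0.625
```

These are the exact values both `sklearn.metrics` and the `pytorch_widedeep` metric classes should agree on for this fixture.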
+###############################################################################
+# Test top_k for Accuracy
+###############################################################################
 @pytest.mark.parametrize("top_k, expected_acc", [(1, 0.33), (2, 0.66)])
-def test_categorical_accuracy(top_k, expected_acc):
+def test_categorical_accuracy_topk(top_k, expected_acc):
     y_true = torch.from_numpy(np.random.choice(3, 100))
     y_pred = torch.from_numpy(np.random.rand(100, 3))
-    metric = CategoricalAccuracy(top_k=top_k)
+    metric = Accuracy(top_k=top_k)
     acc = metric(y_pred, y_true)
     assert np.isclose(acc, expected_acc, atol=0.3)
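The `top_k` test uses random predictions over 3 classes, so top-1 accuracy tends to 1/3 and top-2 to 2/3, hence the wide `atol=0.3`. A numpy sketch of what top-k accuracy computes (a hypothetical helper for illustration, not the package's implementation):

```python
import numpy as np

def top_k_accuracy(y_prob, y_true, k=1):
    # a row counts as correct if the true class is among the k
    # highest-probability classes for that row
    top_k = np.argsort(y_prob, axis=1)[:, -k:]
    hits = (top_k == y_true[:, None]).any(axis=1)
    return hits.mean()

y_true = np.array([2, 0, 1])
y_prob = np.array([[0.1, 0.3, 0.6],
                   [0.2, 0.5, 0.3],
                   [0.1, 0.4, 0.5]])

acc1 = top_k_accuracy(y_prob, y_true, k=1)  # 1/3: only row 0's argmax is right
acc2 = top_k_accuracy(y_prob, y_true, k=2)  # 2/3: row 2's true class is 2nd best
```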
+###############################################################################
+# Test multiclass metrics
+###############################################################################
+y_true_multi_np = np.array([1, 0, 2, 1, 1, 2, 2, 0, 0, 0])
+y_pred_muli_np = np.array(
+    [
+        [0.2, 0.6, 0.2],
+        [0.4, 0.5, 0.1],
+        [0.1, 0.1, 0.8],
+        [0.1, 0.6, 0.3],
+        [0.1, 0.8, 0.1],
+        [0.1, 0.6, 0.6],
+        [0.2, 0.6, 0.8],
+        [0.6, 0.1, 0.3],
+        [0.7, 0.2, 0.1],
+        [0.1, 0.7, 0.2],
+    ]
+)
+y_true_multi_pt = torch.from_numpy(y_true_multi_np)
+y_pred_multi_pt = torch.from_numpy(y_pred_muli_np)

+def f2_score_multi(y_true, y_pred, average):
+    return fbeta_score(y_true, y_pred, average=average, beta=2)

+@pytest.mark.parametrize(
+    "sklearn_metric, widedeep_metric",
+    [
+        (accuracy_score, Accuracy()),
+        (precision_score, Precision()),
+        (recall_score, Recall()),
+        (f1_score, F1Score()),
+        (f2_score_multi, FBetaScore(beta=2)),
+    ],
+)
+def test_muticlass_metrics(sklearn_metric, widedeep_metric):
+    if sklearn_metric.__name__ == "accuracy_score":
+        assert np.isclose(
+            sklearn_metric(y_true_multi_np, y_pred_muli_np.argmax(axis=1)),
+            widedeep_metric(y_pred_multi_pt, y_true_multi_pt),
+        )
+    else:
+        assert np.isclose(
+            sklearn_metric(
+                y_true_multi_np, y_pred_muli_np.argmax(axis=1), average="macro"
+            ),
+            widedeep_metric(y_pred_multi_pt, y_true_multi_pt),
+        )
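For the multiclass fixtures, the test compares against sklearn with `average="macro"`: compute the metric per class, then take the unweighted mean, which is what `average=True` does in the metric classes. The argmax of the probability rows above is `[1, 1, 2, 1, 1, 1, 2, 0, 0, 1]`, and the macro scores can be worked out by hand. A numpy sketch (a hypothetical helper, ignoring sklearn's zero-division handling since every class is predicted here):

```python
import numpy as np

y_true = np.array([1, 0, 2, 1, 1, 2, 2, 0, 0, 0])
y_pred = np.array([1, 1, 2, 1, 1, 1, 2, 0, 0, 1])  # argmax of the prob rows

def macro_precision_recall(y_true, y_pred, n_classes=3):
    precs, recs = [], []
    for c in range(n_classes):
        tp = ((y_pred == c) & (y_true == c)).sum()
        precs.append(tp / (y_pred == c).sum())  # per-class precision
        recs.append(tp / (y_true == c).sum())   # per-class recall
    # macro average: unweighted mean over classes
    return np.mean(precs), np.mean(recs)

prec, rec = macro_precision_recall(y_true, y_pred)
# per-class precision [1.0, 0.5, 1.0] -> macro 0.8333...
# per-class recall    [0.5, 1.0, 2/3] -> macro 0.7222...
```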
tests/test_warm_up/test_warm_up_routines.py
...
...
@@ -9,7 +9,7 @@ from sklearn.utils import Bunch
 from torch.utils.data import Dataset, DataLoader

 from pytorch_widedeep.models import Wide, DeepDense
-from pytorch_widedeep.metrics import BinaryAccuracy
+from pytorch_widedeep.metrics import Accuracy
 from pytorch_widedeep.models._warmup import WarmUp
 from pytorch_widedeep.models.deep_image import conv_layer

...

@@ -138,7 +138,7 @@ wdset = WDset(X_wide, X_deep, X_text, X_image, target)
 wdloader = DataLoader(wdset, batch_size=10, shuffle=True)

 # Instantiate the WarmUp class
-warmer = WarmUp(loss_fn, BinaryAccuracy(), "binary", False)
+warmer = WarmUp(loss_fn, Accuracy(), "binary", False)

 # List the layers for the warm_gradual method
 text_layers = [c for c in list(deeptext.children())[1:]][::-1]
...
...