Greenplum / Pytorch Widedeep
Commit d13e86f9

Authored December 11, 2021 by Pavol Mulinka

fixed multiclass torchmetrics

Parent: ef9ea277
Showing 4 changed files with 127 additions and 53 deletions:

- examples/15_AUC_multiclass.ipynb (+37, -47)
- pytorch_widedeep/metrics.py (+75, -2)
- pytorch_widedeep/training/trainer.py (+7, -3)
- tests/test_metrics/test_torchmetrics.py (+8, -1)
examples/15_AUC_multiclass.ipynb

This diff is collapsed.
pytorch_widedeep/metrics.py

 import numpy as np
 import torch
 from torchmetrics import Metric as TorchMetric
+from torchmetrics import AUC

 from .wdtypes import *  # noqa: F403
...
@@ -38,10 +39,23 @@ class MultipleMetrics(object):
             if isinstance(metric, Metric):
                 logs[self.prefix + metric._name] = metric(y_pred, y_true)
             if isinstance(metric, TorchMetric):
+                if not hasattr(metric, "num_classes"):
+                    raise ValueError(
+                        """TorchMetric does not have num_classes attribute.
+                        Use metric in this library or extend the metric by num_classes attribute,
+                        see `examples <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`
+                        """
+                    )
                 if metric.num_classes == 2:
-                    metric.update(torch.round(y_pred).int(), y_true.int())
+                    if isinstance(metric, AUC):
+                        metric.update(torch.round(y_pred).int(), y_true.int())
+                    else:
+                        metric.update(y_pred, y_true.int())
                 if metric.num_classes > 2:  # type: ignore[operator]
-                    metric.update(torch.max(y_pred, dim=1).indices, y_true.int())  # type: ignore[attr-defined]
+                    if isinstance(metric, AUC):
+                        metric.update(torch.max(y_pred, dim=1).indices, y_true.int())  # type: ignore[attr-defined]
+                    else:
+                        metric.update(y_pred, y_true.int())  # type: ignore[attr-defined]
                 logs[self.prefix + type(metric).__name__] = (
                     metric.compute().detach().cpu().numpy()
                 )
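The branch introduced above feeds AUC hard class predictions (rounded probabilities for binary, argmax indices for multiclass) while other torchmetrics keep receiving raw probabilities. A minimal sketch of that preprocessing, with numpy standing in for torch and `preprocess_for_update` as an illustrative name that is not in the commit:

```python
import numpy as np

def preprocess_for_update(y_pred, num_classes, hard_labels):
    # hard_labels=True mirrors the isinstance(metric, AUC) branch
    if not hard_labels:
        return y_pred  # probabilities are passed through unchanged
    if num_classes == 2:
        # binary: torch.round(y_pred).int() in the diff
        return np.round(y_pred).astype(int)
    # multiclass: torch.max(y_pred, dim=1).indices in the diff
    return y_pred.argmax(axis=1)

binary_probs = np.array([0.2, 0.8, 0.6])
multi_probs = np.array([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
print(preprocess_for_update(binary_probs, 2, True))   # [0 1 1]
print(preprocess_for_update(multi_probs, 3, True))    # [1 0]
```

Non-AUC metrics skip this step entirely, which is the actual behavior change of the commit: before it, every TorchMetric received hard labels.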
...
@@ -396,3 +410,62 @@ class R2Score(Metric):
         y_true_avg = self.y_true_sum / self.num_examples
         self.denominator += ((y_true - y_true_avg) ** 2).sum().item()
         return np.array((1 - (self.numerator / self.denominator)))
+
+
+class Accuracy(Metric):
+    r"""Class to calculate the accuracy for both binary and categorical problems
+
+    Parameters
+    ----------
+    top_k: int, default = 1
+        Accuracy will be computed using the top k most likely classes in
+        multiclass problems
+
+    Examples
+    --------
+    >>> import torch
+    >>>
+    >>> from pytorch_widedeep.metrics import Accuracy
+    >>>
+    >>> acc = Accuracy()
+    >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
+    >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
+    >>> acc(y_pred, y_true)
+    array(0.5)
+    >>>
+    >>> acc = Accuracy(top_k=2)
+    >>> y_true = torch.tensor([0, 1, 2])
+    >>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
+    >>> acc(y_pred, y_true)
+    array(0.66666667)
+    """
+
+    def __init__(self, top_k: int = 1):
+        super(Accuracy, self).__init__()
+
+        self.top_k = top_k
+        self.correct_count = 0
+        self.total_count = 0
+        self._name = "acc"
+
+    def reset(self):
+        """
+        resets counters to 0
+        """
+        self.correct_count = 0
+        self.total_count = 0
+
+    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+        num_classes = y_pred.size(1)
+
+        if num_classes == 1:
+            y_pred = y_pred.round()
+            y_true = y_true
+        elif num_classes > 1:
+            y_pred = y_pred.topk(self.top_k, 1)[1]
+            y_true = y_true.view(-1, 1).expand_as(y_pred)
+
+        self.correct_count += y_pred.eq(y_true).sum().item()  # type: ignore[assignment]
+        self.total_count += len(y_pred)
+        accuracy = float(self.correct_count) / float(self.total_count)
+        return np.array(accuracy)
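The added Accuracy class counts a multiclass prediction as correct when the true label appears among the top_k highest scores. That logic can be mirrored in plain numpy (the helper name is illustrative; the class itself operates on torch tensors):

```python
import numpy as np

def topk_accuracy(y_pred, y_true, top_k=1):
    # indices of the top_k highest scores per row; negating and using a
    # stable sort resolves ties toward the earlier index, as Tensor.topk does
    topk_idx = np.argsort(-y_pred, axis=1, kind="stable")[:, :top_k]
    # a sample is correct if any of its top_k indices matches the true label
    hits = (topk_idx == y_true.reshape(-1, 1)).any(axis=1)
    return hits.mean()

# the docstring's top-2 example: rows 1 and 3 hit, row 2 misses
y_true = np.array([0, 1, 2])
y_pred = np.array([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
print(topk_accuracy(y_pred, y_true, top_k=2))  # 0.666...
```

This reproduces the `array(0.66666667)` value in the class docstring above.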
pytorch_widedeep/training/trainer.py

...
@@ -147,10 +147,14 @@ class Trainer:
         <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
         folder in the repo
         - List of objects of type :obj:`torchmetrics.Metric`. This can be any
-          metric from torchmetrics library `Examples
+          metric from torchmetrics library that has attribute num_classes `Examples
           <https://torchmetrics.readthedocs.io/en/latest/references/modules.html#
-          classification-metrics>`_. This can also be a custom metric as
-          long as it is an object of type :obj:`Metric`. See `the instructions
+          classification-metrics>`_.
+          Objects of type :obj:`torchmetrics.Metric` can be extended with num_classes
+          attribute to be used with the Trainer object, see `examples
+          <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`.
+          This can also be a custom metric as long as it is an object of
+          type :obj:`Metric`. See `the instructions
           <https://torchmetrics.readthedocs.io/en/latest/>`_.
     class_weight: float, List or Tuple. optional. default=None
         - float indicating the weight of the minority class in binary classification
...
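The updated docstring tells users to extend a torchmetrics metric with a `num_classes` attribute before handing it to the Trainer, since `MultipleMetrics` now raises when the attribute is missing. A self-contained sketch of that pattern, with a dummy `BaseAUC` standing in for `torchmetrics.AUC` (class names here are illustrative, not from the commit):

```python
# Dummy stand-in for torchmetrics.AUC, which lacks a num_classes attribute
class BaseAUC:
    def __init__(self):
        self.preds = []

# Subclass that adds the attribute MultipleMetrics checks with hasattr()
class AUCWithNumClasses(BaseAUC):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

metric = AUCWithNumClasses(num_classes=3)
print(hasattr(metric, "num_classes"), metric.num_classes)  # True 3
```

The test file below takes the simpler route of setting the attribute directly on an instance (`torch_metric.num_classes = 3`); subclassing keeps the attribute tied to construction instead.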
tests/test_metrics/test_torchmetrics.py

 import numpy as np
 import torch
 import pytest
-from torchmetrics import F1, FBeta, Recall, Accuracy, Precision
+from torchmetrics import F1, FBeta, Recall, Accuracy, Precision, AUC
 from sklearn.metrics import (
     f1_score,
     fbeta_score,
     recall_score,
     accuracy_score,
     precision_score,
+    auc_score,
 )

 from pytorch_widedeep.metrics import MultipleMetrics
...
@@ -35,9 +36,12 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
         ("Recall", recall_score, Recall(num_classes=2, average="none")),
         ("F1", f1_score, F1(num_classes=2, average="none")),
         ("FBeta", f2_score_bin, FBeta(beta=2, num_classes=2, average="none")),
+        ("AUC", auc_score, AUC()),
     ],
 )
 def test_binary_metrics(metric_name, sklearn_metric, torch_metric):
+    if metric_name == "AUC":
+        torch_metric.num_classes = 2
     sk_res = sklearn_metric(y_true_bin_np, y_pred_bin_np.round())
     wd_metric = MultipleMetrics(metrics=[torch_metric])
     wd_logs = wd_metric(y_pred_bin_pt, y_true_bin_pt)
...
@@ -82,11 +86,14 @@ def f2_score_multi(y_true, y_pred, average):
         ("Recall", recall_score, Recall(num_classes=3, average="macro")),
         ("F1", f1_score, F1(num_classes=3, average="macro")),
         ("FBeta", f2_score_multi, FBeta(beta=3, num_classes=3, average="macro")),
+        ("AUC", auc_score, AUC()),
     ],
 )
 def test_muticlass_metrics(metric_name, sklearn_metric, torch_metric):
     if metric_name == "Accuracy":
         sk_res = sklearn_metric(y_true_multi_np, y_pred_muli_np.argmax(axis=1))
+    elif metric_name == "AUC":
+        torch_metric.num_classes = 3
     else:
         sk_res = sklearn_metric(
             y_true_multi_np, y_pred_muli_np.argmax(axis=1), average="macro"
...