Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Pytorch Widedeep
提交
d13e86f9
P
Pytorch Widedeep
项目概览
Greenplum
/
Pytorch Widedeep
11 个月 前同步成功
通知
9
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Pytorch Widedeep
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d13e86f9
编写于
12月 11, 2021
作者:
P
Pavol Mulinka
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixed multiclass torchmetrics
上级
ef9ea277
变更
4
展开全部
隐藏空白更改
内联
并排
Showing
4 changed file
with
127 addition
and
53 deletion
+127
-53
examples/15_AUC_multiclass.ipynb
examples/15_AUC_multiclass.ipynb
+37
-47
pytorch_widedeep/metrics.py
pytorch_widedeep/metrics.py
+75
-2
pytorch_widedeep/training/trainer.py
pytorch_widedeep/training/trainer.py
+7
-3
tests/test_metrics/test_torchmetrics.py
tests/test_metrics/test_torchmetrics.py
+8
-1
未找到文件。
examples/15_AUC_multiclass.ipynb
浏览文件 @
d13e86f9
此差异已折叠。
点击以展开。
pytorch_widedeep/metrics.py
浏览文件 @
d13e86f9
import
numpy
as
np
import
torch
from
torchmetrics
import
Metric
as
TorchMetric
from
torchmetrics
import
AUC
from
.wdtypes
import
*
# noqa: F403
...
...
@@ -38,10 +39,23 @@ class MultipleMetrics(object):
if
isinstance
(
metric
,
Metric
):
logs
[
self
.
prefix
+
metric
.
_name
]
=
metric
(
y_pred
,
y_true
)
if
isinstance
(
metric
,
TorchMetric
):
if
not
hasattr
(
metric
,
"num_classes"
):
raise
ValueError
(
"""TorchMetric does not have num_classes attribute.
Use metric in this library or extend the metric by num_classes attribute,
see `examples <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`
"""
)
if
metric
.
num_classes
==
2
:
metric
.
update
(
torch
.
round
(
y_pred
).
int
(),
y_true
.
int
())
if
isinstance
(
metric
,
AUC
):
metric
.
update
(
torch
.
round
(
y_pred
).
int
(),
y_true
.
int
())
else
:
metric
.
update
(
y_pred
,
y_true
.
int
())
if
metric
.
num_classes
>
2
:
# type: ignore[operator]
metric
.
update
(
torch
.
max
(
y_pred
,
dim
=
1
).
indices
,
y_true
.
int
())
# type: ignore[attr-defined]
if
isinstance
(
metric
,
AUC
):
metric
.
update
(
torch
.
max
(
y_pred
,
dim
=
1
).
indices
,
y_true
.
int
())
# type: ignore[attr-defined]
else
:
metric
.
update
(
y_pred
,
y_true
.
int
())
# type: ignore[attr-defined]
logs
[
self
.
prefix
+
type
(
metric
).
__name__
]
=
(
metric
.
compute
().
detach
().
cpu
().
numpy
()
)
...
...
@@ -396,3 +410,62 @@ class R2Score(Metric):
y_true_avg
=
self
.
y_true_sum
/
self
.
num_examples
self
.
denominator
+=
((
y_true
-
y_true_avg
)
**
2
).
sum
().
item
()
return
np
.
array
((
1
-
(
self
.
numerator
/
self
.
denominator
)))
class
Accuracy
(
Metric
):
r
"""Class to calculate the accuracy for both binary and categorical problems
Parameters
----------
top_k: int, default = 1
Accuracy will be computed using the top k most likely classes in
multiclass problems
Examples
--------
>>> import torch
>>>
>>> from pytorch_widedeep.metrics import Accuracy
>>>
>>> acc = Accuracy()
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> acc(y_pred, y_true)
array(0.5)
>>>
>>> acc = Accuracy(top_k=2)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> acc(y_pred, y_true)
array(0.66666667)
"""
def
__init__
(
self
,
top_k
:
int
=
1
):
super
(
Accuracy
,
self
).
__init__
()
self
.
top_k
=
top_k
self
.
correct_count
=
0
self
.
total_count
=
0
self
.
_name
=
"acc"
def
reset
(
self
):
"""
resets counters to 0
"""
self
.
correct_count
=
0
self
.
total_count
=
0
def
__call__
(
self
,
y_pred
:
Tensor
,
y_true
:
Tensor
)
->
np
.
ndarray
:
num_classes
=
y_pred
.
size
(
1
)
if
num_classes
==
1
:
y_pred
=
y_pred
.
round
()
y_true
=
y_true
elif
num_classes
>
1
:
y_pred
=
y_pred
.
topk
(
self
.
top_k
,
1
)[
1
]
y_true
=
y_true
.
view
(
-
1
,
1
).
expand_as
(
y_pred
)
self
.
correct_count
+=
y_pred
.
eq
(
y_true
).
sum
().
item
()
# type: ignore[assignment]
self
.
total_count
+=
len
(
y_pred
)
accuracy
=
float
(
self
.
correct_count
)
/
float
(
self
.
total_count
)
return
np
.
array
(
accuracy
)
pytorch_widedeep/training/trainer.py
浏览文件 @
d13e86f9
...
...
@@ -147,10 +147,14 @@ class Trainer:
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo
- List of objects of type :obj:`torchmetrics.Metric`. This can be any
metric from torchmetrics library `Examples
metric from torchmetrics library
that has attribute num_classes
`Examples
<https://torchmetrics.readthedocs.io/en/latest/references/modules.html#
classification-metrics>`_. This can also be a custom metric as
long as it is an object of type :obj:`Metric`. See `the instructions
classification-metrics>`_.
Objects of type :obj:`torchmetrics.Metric` can be extended with num_classes
attribute to be used with the Trainer object, see `examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`.
This can also be a custom metric as long as it is an object of
type :obj:`Metric`. See `the instructions
<https://torchmetrics.readthedocs.io/en/latest/>`_.
class_weight: float, List or Tuple. optional. default=None
- float indicating the weight of the minority class in binary classification
...
...
tests/test_metrics/test_torchmetrics.py
浏览文件 @
d13e86f9
import
numpy
as
np
import
torch
import
pytest
from
torchmetrics
import
F1
,
FBeta
,
Recall
,
Accuracy
,
Precision
from
torchmetrics
import
F1
,
FBeta
,
Recall
,
Accuracy
,
Precision
,
AUC
from
sklearn.metrics
import
(
f1_score
,
fbeta_score
,
recall_score
,
accuracy_score
,
precision_score
,
auc_score
,
)
from
pytorch_widedeep.metrics
import
MultipleMetrics
...
...
@@ -35,9 +36,12 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
(
"Recall"
,
recall_score
,
Recall
(
num_classes
=
2
,
average
=
"none"
)),
(
"F1"
,
f1_score
,
F1
(
num_classes
=
2
,
average
=
"none"
)),
(
"FBeta"
,
f2_score_bin
,
FBeta
(
beta
=
2
,
num_classes
=
2
,
average
=
"none"
)),
(
"AUC"
,
auc_score
,
AUC
()),
],
)
def
test_binary_metrics
(
metric_name
,
sklearn_metric
,
torch_metric
):
if
metric_name
==
"AUC"
:
torch_metric
.
num_classes
=
2
sk_res
=
sklearn_metric
(
y_true_bin_np
,
y_pred_bin_np
.
round
())
wd_metric
=
MultipleMetrics
(
metrics
=
[
torch_metric
])
wd_logs
=
wd_metric
(
y_pred_bin_pt
,
y_true_bin_pt
)
...
...
@@ -82,11 +86,14 @@ def f2_score_multi(y_true, y_pred, average):
(
"Recall"
,
recall_score
,
Recall
(
num_classes
=
3
,
average
=
"macro"
)),
(
"F1"
,
f1_score
,
F1
(
num_classes
=
3
,
average
=
"macro"
)),
(
"FBeta"
,
f2_score_multi
,
FBeta
(
beta
=
3
,
num_classes
=
3
,
average
=
"macro"
)),
(
"AUC"
,
auc_score
,
AUC
()),
],
)
def
test_muticlass_metrics
(
metric_name
,
sklearn_metric
,
torch_metric
):
if
metric_name
==
"Accuracy"
:
sk_res
=
sklearn_metric
(
y_true_multi_np
,
y_pred_muli_np
.
argmax
(
axis
=
1
))
elif
metric_name
==
"AUC"
:
torch_metric
.
num_classes
=
3
else
:
sk_res
=
sklearn_metric
(
y_true_multi_np
,
y_pred_muli_np
.
argmax
(
axis
=
1
),
average
=
"macro"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录