提交 707e01ae 编写于 作者: C cuicheng01

Update GoogLeNetLoss

上级 4e154aed
......@@ -122,8 +122,8 @@ Infer:
Metric:
Train:
- TopkAcc:
- GoogLeNetTopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
- GoogLeNetTopkAcc:
topk: [1, 5]
......@@ -18,6 +18,7 @@ from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk
from .metrics import DistillationTopkAcc
from .metrics import GoogLeNetTopkAcc
class CombinedMetrics(nn.Layer):
def __init__(self, config_list):
......
......@@ -25,8 +25,6 @@ class TopkAcc(nn.Layer):
self.topk = topk
def forward(self, x, label):
if isinstance(x, list):
x = x[0]
if isinstance(x, dict):
x = x["logits"]
......@@ -122,3 +120,16 @@ class DistillationTopkAcc(TopkAcc):
if self.feature_key is not None:
x = x[self.feature_key]
return super().forward(x, label)
class GoogLeNetTopkAcc(TopkAcc):
def __init__(self, topk=(1, 5)):
super().__init__()
assert isinstance(topk, (int, list, tuple))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
def forward(self, x, label):
return super().forward(x[0], label)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册