提交 707e01ae 编写于 作者: C cuicheng01

Update GoogLeNetLoss

上级 4e154aed
...@@ -122,8 +122,8 @@ Infer: ...@@ -122,8 +122,8 @@ Infer:
Metric: Metric:
Train: Train:
- TopkAcc: - GoogLeNetTopkAcc:
topk: [1, 5] topk: [1, 5]
Eval: Eval:
- TopkAcc: - GoogLeNetTopkAcc:
topk: [1, 5] topk: [1, 5]
...@@ -18,6 +18,7 @@ from collections import OrderedDict ...@@ -18,6 +18,7 @@ from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk from .metrics import TopkAcc, mAP, mINP, Recallk
from .metrics import DistillationTopkAcc from .metrics import DistillationTopkAcc
from .metrics import GoogLeNetTopkAcc
class CombinedMetrics(nn.Layer): class CombinedMetrics(nn.Layer):
def __init__(self, config_list): def __init__(self, config_list):
......
...@@ -25,8 +25,6 @@ class TopkAcc(nn.Layer): ...@@ -25,8 +25,6 @@ class TopkAcc(nn.Layer):
self.topk = topk self.topk = topk
def forward(self, x, label): def forward(self, x, label):
if isinstance(x, list):
x = x[0]
if isinstance(x, dict): if isinstance(x, dict):
x = x["logits"] x = x["logits"]
...@@ -122,3 +120,16 @@ class DistillationTopkAcc(TopkAcc): ...@@ -122,3 +120,16 @@ class DistillationTopkAcc(TopkAcc):
if self.feature_key is not None: if self.feature_key is not None:
x = x[self.feature_key] x = x[self.feature_key]
return super().forward(x, label) return super().forward(x, label)
class GoogLeNetTopkAcc(TopkAcc):
def __init__(self, topk=(1, 5)):
super().__init__()
assert isinstance(topk, (int, list, tuple))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
def forward(self, x, label):
return super().forward(x[0], label)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册