提交 87a0ba6f 编写于 作者: C cuicheng01

update metrics

上级 93e9970e
......@@ -39,6 +39,7 @@ class CombinedMetrics(AvgMetrics):
eval(metric_name)(**metric_params))
else:
self.metric_func_list.append(eval(metric_name)())
self.reset()
def forward(self, *args, **kwargs):
metric_dict = OrderedDict()
......@@ -54,6 +55,10 @@ class CombinedMetrics(AvgMetrics):
def avg(self):
return self.metric_func_list[0].avg
def reset(self):
for metric in self.metric_func_list:
if hasattr(metric, "reset"):
metric.reset()
def build_metrics(config):
metrics_list = CombinedMetrics(copy.deepcopy(config))
......
......@@ -33,6 +33,9 @@ class TopkAcc(AvgMetrics):
if isinstance(topk, int):
topk = [topk]
self.topk = topk
self.reset()
def reset(self):
self.avg_meters = {"top{}".format(k): AverageMeter("top{}".format(k)) for k in self.topk}
def forward(self, x, label):
......@@ -316,6 +319,9 @@ class HammingDistance(MultiLabelMetric):
def __init__(self):
super().__init__()
self.reset()
def reset(self):
self.avg_meters = {"HammingDistance": AverageMeter("HammingDistance")}
def forward(self, output, target):
......@@ -343,6 +349,10 @@ class AccuracyScore(MultiLabelMetric):
assert base in ["sample", "label"
], 'must be one of ["sample", "label"]'
self.base = base
self.reset()
def reset(self):
self.avg_meters = {"AccuracyScore": AverageMeter("AccuracyScore")}
def forward(self, output, target):
preds = super()._multi_hot_encode(output)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册