From 87a0ba6f316ef1068cfa9c28819b590e19eafc65 Mon Sep 17 00:00:00 2001 From: cuicheng01 Date: Mon, 16 May 2022 04:31:26 +0000 Subject: [PATCH] update metrics --- ppcls/metric/__init__.py | 5 +++++ ppcls/metric/metrics.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py index 2bf45125..fa3511c8 100644 --- a/ppcls/metric/__init__.py +++ b/ppcls/metric/__init__.py @@ -39,6 +39,7 @@ class CombinedMetrics(AvgMetrics): eval(metric_name)(**metric_params)) else: self.metric_func_list.append(eval(metric_name)()) + self.reset() def forward(self, *args, **kwargs): metric_dict = OrderedDict() @@ -54,6 +55,10 @@ class CombinedMetrics(AvgMetrics): def avg(self): return self.metric_func_list[0].avg + def reset(self): + for metric in self.metric_func_list: + if hasattr(metric, "reset"): + metric.reset() def build_metrics(config): metrics_list = CombinedMetrics(copy.deepcopy(config)) diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 62882df5..422e0b6d 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -33,6 +33,9 @@ class TopkAcc(AvgMetrics): if isinstance(topk, int): topk = [topk] self.topk = topk + self.reset() + + def reset(self): self.avg_meters = {"top{}".format(k): AverageMeter("top{}".format(k)) for k in self.topk} def forward(self, x, label): @@ -316,6 +319,9 @@ class HammingDistance(MultiLabelMetric): def __init__(self): super().__init__() + self.reset() + + def reset(self): self.avg_meters = {"HammingDistance": AverageMeter("HammingDistance")} def forward(self, output, target): @@ -343,6 +349,10 @@ class AccuracyScore(MultiLabelMetric): assert base in ["sample", "label" ], 'must be one of ["sample", "label"]' self.base = base + self.reset() + + def reset(self): + self.avg_meters = {"AccuracyScore": AverageMeter("AccuracyScore")} def forward(self, output, target): preds = super()._multi_hot_encode(output) -- GitLab