From 6b173e630da2103b4430785532a683c65ac8a3d8 Mon Sep 17 00:00:00 2001 From: "Felix.Chen" Date: Wed, 27 Jul 2022 10:35:03 +0800 Subject: [PATCH] add configurable ignore classes for KIEMetric --- configs/kie/kie_unet_sdmgr.yml | 2 ++ ppocr/metrics/kie_metric.py | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/configs/kie/kie_unet_sdmgr.yml b/configs/kie/kie_unet_sdmgr.yml index a6968aaa..0d20bd57 100644 --- a/configs/kie/kie_unet_sdmgr.yml +++ b/configs/kie/kie_unet_sdmgr.yml @@ -54,6 +54,8 @@ PostProcess: Metric: name: KIEMetric main_indicator: hmean + # Classes that will be ignored while computing F1 score. + ignore_classes: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25] Train: dataset: diff --git a/ppocr/metrics/kie_metric.py b/ppocr/metrics/kie_metric.py index 28ab22b8..93e3d6fb 100644 --- a/ppocr/metrics/kie_metric.py +++ b/ppocr/metrics/kie_metric.py @@ -24,8 +24,12 @@ __all__ = ['KIEMetric'] class KIEMetric(object): - def __init__(self, main_indicator='hmean', **kwargs): + def __init__(self, + main_indicator='hmean', + ignore_classes=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25], + **kwargs): self.main_indicator = main_indicator + self.ignore_classes = ignore_classes self.reset() self.node = [] self.gt = [] @@ -40,7 +44,7 @@ class KIEMetric(object): # self.results.append(result) def compute_f1_score(self, preds, gts): - ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25] + ignores = self.ignore_classes C = preds.shape[1] classes = np.array(sorted(set(range(C)) - set(ignores))) hist = np.bincount( -- GitLab