diff --git a/configs/kie/kie_unet_sdmgr.yml b/configs/kie/kie_unet_sdmgr.yml index a6968aaa3aa7a717a848416efc5ccc567f774b4d..0d20bd57e02b2197172f2ca410d5740305b2f799 100644 --- a/configs/kie/kie_unet_sdmgr.yml +++ b/configs/kie/kie_unet_sdmgr.yml @@ -54,6 +54,8 @@ PostProcess: Metric: name: KIEMetric main_indicator: hmean + # Classes that will be ignored while computing F1 score. + ignore_classes: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25] Train: dataset: diff --git a/ppocr/metrics/kie_metric.py b/ppocr/metrics/kie_metric.py index 28ab22b807ffe4347ca95be5a44143539db4c97c..93e3d6fb57fbc9396f21f6b9b829911d3f528d3e 100644 --- a/ppocr/metrics/kie_metric.py +++ b/ppocr/metrics/kie_metric.py @@ -24,8 +24,12 @@ __all__ = ['KIEMetric'] class KIEMetric(object): - def __init__(self, main_indicator='hmean', **kwargs): + def __init__(self, + main_indicator='hmean', + ignore_classes=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25], + **kwargs): self.main_indicator = main_indicator + self.ignore_classes = ignore_classes self.reset() self.node = [] self.gt = [] @@ -40,7 +44,7 @@ class KIEMetric(object): # self.results.append(result) def compute_f1_score(self, preds, gts): - ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25] + ignores = self.ignore_classes C = preds.shape[1] classes = np.array(sorted(set(range(C)) - set(ignores))) hist = np.bincount(