From 60a1408d4eb29d5dc16e07d902d4cd691782c510 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Wed, 19 Jan 2022 07:21:51 +0000 Subject: [PATCH] add eps --- ppocr/metrics/cls_metric.py | 5 +++-- ppocr/metrics/rec_metric.py | 9 +++++---- ppocr/metrics/table_metric.py | 9 +++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/ppocr/metrics/cls_metric.py b/ppocr/metrics/cls_metric.py index 09817200..6c077518 100644 --- a/ppocr/metrics/cls_metric.py +++ b/ppocr/metrics/cls_metric.py @@ -16,6 +16,7 @@ class ClsMetric(object): def __init__(self, main_indicator='acc', **kwargs): self.main_indicator = main_indicator + self.eps = 1e-5 self.reset() def __call__(self, pred_label, *args, **kwargs): @@ -28,7 +29,7 @@ class ClsMetric(object): all_num += 1 self.correct_num += correct_num self.all_num += all_num - return {'acc': correct_num / all_num, } + return {'acc': correct_num / (all_num + self.eps), } def get_metric(self): """ @@ -36,7 +37,7 @@ class ClsMetric(object): 'acc': 0 } """ - acc = self.correct_num / self.all_num + acc = self.correct_num / (self.all_num + self.eps) self.reset() return {'acc': acc} diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index b0ccd974..b047bbcb 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -20,6 +20,7 @@ class RecMetric(object): def __init__(self, main_indicator='acc', is_filter=False, **kwargs): self.main_indicator = main_indicator self.is_filter = is_filter + self.eps = 1e-5 self.reset() def _normalize_text(self, text): @@ -47,8 +48,8 @@ class RecMetric(object): self.all_num += all_num self.norm_edit_dis += norm_edit_dis return { - 'acc': correct_num / all_num, - 'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3) + 'acc': correct_num / (all_num + self.eps), + 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps) } def get_metric(self): @@ -58,8 +59,8 @@ class RecMetric(object): 'norm_edit_dis': 0, } """ - acc = 1.0 * self.correct_num / (self.all_num + 1e-3) - norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3) + acc = 1.0 * self.correct_num / (self.all_num + self.eps) + norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps) self.reset() return {'acc': acc, 'norm_edit_dis': norm_edit_dis} diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py index 80d1c789..ca4d6474 100644 --- a/ppocr/metrics/table_metric.py +++ b/ppocr/metrics/table_metric.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np + + class TableMetric(object): def __init__(self, main_indicator='acc', **kwargs): self.main_indicator = main_indicator + self.eps = 1e-5 self.reset() def __call__(self, pred, batch, *args, **kwargs): @@ -31,9 +34,7 @@ class TableMetric(object): correct_num += 1 self.correct_num += correct_num self.all_num += all_num - return { - 'acc': correct_num * 1.0 / all_num, - } + return {'acc': correct_num * 1.0 / (all_num + self.eps), } def get_metric(self): """ @@ -41,7 +42,7 @@ class TableMetric(object): 'acc': 0, } """ - acc = 1.0 * self.correct_num / self.all_num + acc = 1.0 * self.correct_num / (self.all_num + self.eps) self.reset() return {'acc': acc} -- GitLab