diff --git a/ppocr/metrics/cls_metric.py b/ppocr/metrics/cls_metric.py index 09817200234dc8d8b5d091ebbe33f07f4aad2cf6..6c077518ce205d4ec4d426aaedb8c0af880122ee 100644 --- a/ppocr/metrics/cls_metric.py +++ b/ppocr/metrics/cls_metric.py @@ -16,6 +16,7 @@ class ClsMetric(object): def __init__(self, main_indicator='acc', **kwargs): self.main_indicator = main_indicator + self.eps = 1e-5 self.reset() def __call__(self, pred_label, *args, **kwargs): @@ -28,7 +29,7 @@ class ClsMetric(object): all_num += 1 self.correct_num += correct_num self.all_num += all_num - return {'acc': correct_num / all_num, } + return {'acc': correct_num / (all_num + self.eps), } def get_metric(self): """ @@ -36,7 +37,7 @@ class ClsMetric(object): 'acc': 0 } """ - acc = self.correct_num / self.all_num + acc = self.correct_num / (self.all_num + self.eps) self.reset() return {'acc': acc} diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index b0ccd974f24f1c7e0c9a8e1d414373021c4288e6..b047bbcb972cadf227daaeb8797c46095ac0af43 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -20,6 +20,7 @@ class RecMetric(object): def __init__(self, main_indicator='acc', is_filter=False, **kwargs): self.main_indicator = main_indicator self.is_filter = is_filter + self.eps = 1e-5 self.reset() def _normalize_text(self, text): @@ -47,8 +48,8 @@ class RecMetric(object): self.all_num += all_num self.norm_edit_dis += norm_edit_dis return { - 'acc': correct_num / all_num, - 'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3) + 'acc': correct_num / (all_num + self.eps), + 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps) } def get_metric(self): @@ -58,8 +59,8 @@ class RecMetric(object): 'norm_edit_dis': 0, } """ - acc = 1.0 * self.correct_num / (self.all_num + 1e-3) - norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3) + acc = 1.0 * self.correct_num / (self.all_num + self.eps) + norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps) self.reset() return {'acc': acc, 'norm_edit_dis': norm_edit_dis} diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py index 80d1c789ecc3979bd4c33620af91ccd28012f7a8..ca4d6474202b4e85cadf86ccb2fe2726c7fa9aeb 100644 --- a/ppocr/metrics/table_metric.py +++ b/ppocr/metrics/table_metric.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np + + class TableMetric(object): def __init__(self, main_indicator='acc', **kwargs): self.main_indicator = main_indicator + self.eps = 1e-5 self.reset() def __call__(self, pred, batch, *args, **kwargs): @@ -31,9 +34,7 @@ class TableMetric(object): correct_num += 1 self.correct_num += correct_num self.all_num += all_num - return { - 'acc': correct_num * 1.0 / all_num, - } + return {'acc': correct_num * 1.0 / (all_num + self.eps), } def get_metric(self): """ @@ -41,7 +42,7 @@ class TableMetric(object): 'acc': 0, } """ - acc = 1.0 * self.correct_num / self.all_num + acc = 1.0 * self.correct_num / (self.all_num + self.eps) self.reset() return {'acc': acc}