From e0d13050eee038f6f5a7a3fcfc60ef3822c260e8 Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Wed, 21 Apr 2021 14:15:51 +0800 Subject: [PATCH] add metric mode --- configs/e2e/e2e_r50_vd_pg.yml | 1 + ppocr/data/imaug/label_ops.py | 26 ++++++++++++++++++++++---- ppocr/metrics/e2e_metric.py | 5 ++--- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index 205a46a5..4c9e5387 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -60,6 +60,7 @@ PostProcess: name: PGPostProcess score_thresh: 0.5 mode: fast # fast or slow two ways + Metric: name: E2EMetric mode: A # A or B diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index aae37ad9..0e3e048c 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -199,15 +199,33 @@ class E2ELabelEncode_test(BaseRecLabelEncode): character_type, use_space_char) def __call__(self, data): - texts = data['texts'] + import json + padnum = len(self.dict) + label = data['label'] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + boxes.append(box) + txts.append(txt) + if txt in ['*', '###']: + txt_tags.append(True) + else: + txt_tags.append(False) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + data['polys'] = boxes + data['ignore_tags'] = txt_tags temp_texts = [] - for text in texts: + for text in txts: text = text.lower() text = self.encode(text) if text is None: return None - text = text + [36] * (self.max_text_len - len(text) - ) # use 36 to pad + text = text + [padnum] * (self.max_text_len - len(text) + ) # use 36 to pad temp_texts.append(text) data['texts'] = np.array(temp_texts) return data diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index aeef43f9..41b7ac2b 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -39,7 +39,7 @@ class E2EMetric(object): def __call__(self, preds, batch, **kwargs): if self.mode == 'A': gt_polyons_batch = batch[2] - temp_gt_strs_batch = batch[3] + temp_gt_strs_batch = batch[3][0] ignore_tags_batch = batch[4] gt_strs_batch = [] @@ -51,8 +51,7 @@ class E2EMetric(object): gt_strs_batch.append(t) for pred, gt_polyons, gt_strs, ignore_tags in zip( - [preds], [gt_polyons_batch], [gt_strs_batch], - ignore_tags_batch): + [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch): # prepare gt gt_info_list = [{ 'points': gt_polyon, -- GitLab