From 1f76f449db25dfc0a1695da487f8e5856935799a Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Mon, 8 Mar 2021 14:15:47 +0800 Subject: [PATCH] Add PGNet --- configs/det/det_r50_vd_sast_icdar15.yml | 5 +- configs/e2e/e2e_r50_vd_pg.yml | 122 +++ ppocr/data/__init__.py | 4 +- ppocr/data/imaug/__init__.py | 1 + ppocr/data/imaug/label_ops.py | 19 + ppocr/data/imaug/operators.py | 71 ++ ppocr/data/imaug/pg_process.py | 921 +++++++++++++++++++ ppocr/data/pgnet_dataset.py | 164 ++++ ppocr/losses/__init__.py | 5 +- ppocr/losses/e2e_pg_loss.py | 219 +++++ ppocr/metrics/__init__.py | 3 +- ppocr/metrics/e2e_metric.py | 87 ++ ppocr/metrics/eval_det_iou.py | 7 +- ppocr/modeling/backbones/__init__.py | 3 + ppocr/modeling/backbones/e2e_resnet_vd_pg.py | 267 ++++++ ppocr/modeling/heads/__init__.py | 5 +- ppocr/modeling/heads/e2e_pg_head.py | 249 +++++ ppocr/modeling/necks/__init__.py | 4 +- ppocr/modeling/necks/pg_fpn.py | 277 ++++++ ppocr/postprocess/__init__.py | 3 +- ppocr/postprocess/pg_postprocess.py | 194 ++++ ppocr/postprocess/sast_postprocess.py | 199 ++-- ppocr/utils/e2e_metric/Deteval.py | 877 ++++++++++++++++++ ppocr/utils/e2e_metric/polygon_fast.py | 71 ++ ppocr/utils/e2e_metric/tttt.py | 881 ++++++++++++++++++ ppocr/utils/e2e_utils/extract_textpoint.py | 532 +++++++++++ ppocr/utils/e2e_utils/ski_thin.py | 126 +++ ppocr/utils/e2e_utils/visual.py | 343 +++++++ tools/infer_e2e.py | 114 +++ tools/program.py | 4 +- 30 files changed, 5691 insertions(+), 86 deletions(-) create mode 100644 configs/e2e/e2e_r50_vd_pg.yml create mode 100644 ppocr/data/imaug/pg_process.py create mode 100644 ppocr/data/pgnet_dataset.py create mode 100644 ppocr/losses/e2e_pg_loss.py create mode 100644 ppocr/metrics/e2e_metric.py create mode 100644 ppocr/modeling/backbones/e2e_resnet_vd_pg.py create mode 100644 ppocr/modeling/heads/e2e_pg_head.py create mode 100644 ppocr/modeling/necks/pg_fpn.py create mode 100644 ppocr/postprocess/pg_postprocess.py create mode 100755 ppocr/utils/e2e_metric/Deteval.py create mode 100755 ppocr/utils/e2e_metric/polygon_fast.py create mode 100644 ppocr/utils/e2e_metric/tttt.py create mode 100644 ppocr/utils/e2e_utils/extract_textpoint.py create mode 100644 ppocr/utils/e2e_utils/ski_thin.py create mode 100644 ppocr/utils/e2e_utils/visual.py create mode 100755 tools/infer_e2e.py diff --git a/configs/det/det_r50_vd_sast_icdar15.yml b/configs/det/det_r50_vd_sast_icdar15.yml index c24cae90..c90327b2 100755 --- a/configs/det/det_r50_vd_sast_icdar15.yml +++ b/configs/det/det_r50_vd_sast_icdar15.yml @@ -14,12 +14,13 @@ Global: load_static_weights: True cal_metric_during_train: False pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ - checkpoints: + checkpoints: save_inference_dir: use_visualdl: False - infer_img: + infer_img: save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt + Architecture: model_type: det algorithm: SAST diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml new file mode 100644 index 00000000..05a13135 --- /dev/null +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -0,0 +1,122 @@ +Global: + use_gpu: False + epoch_num: 600 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/pg_r50_vd_tt/ + save_epoch_step: 1 + # evaluation is run every 5000 iterationss after the 4000th iteration + eval_batch_step: [ 0, 1000 ] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: False + cal_metric_during_train: False + pretrained_model: + checkpoints: + save_inference_dir: + 
use_visualdl: False + infer_img: + save_res_path: ./output/pg_r50_vd_tt/predicts_pg.txt + +Architecture: + model_type: e2e + algorithm: PG + Transform: + Backbone: + name: ResNet + layers: 50 + Neck: + name: PGFPN + model_name: large + Head: + name: PGHead + model_name: large + +Loss: + name: PGLoss + +#Optimizer: +# name: Adam +# beta1: 0.9 +# beta2: 0.999 +# lr: +# name: Cosine +# learning_rate: 0.001 +# warmup_epoch: 1 +# regularizer: +# name: 'L2' +# factor: 0 + +Optimizer: + name: RMSProp + lr: + name: Piecewise + learning_rate: 0.001 + decay_epochs: [ 40, 80, 120, 160, 200 ] + values: [ 0.001, 0.00033, 0.0001, 0.000033, 0.00001 ] + regularizer: + name: 'L2' + factor: 0.00005 + +PostProcess: + name: PGPostProcess + score_thresh: 0.8 + cover_thresh: 0.1 + nms_thresh: 0.2 + +Metric: + name: E2EMetric + main_indicator: hmean + +Train: + dataset: + name: PGDateSet + label_file_list: + ratio_list: + data_format: textnet # textnet/partvgg + Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - PGProcessTrain: + batch_size: 14 + data_format: icdar + tcl_len: 64 + min_crop_size: 24 + min_text_size: 4 + max_text_size: 512 + - KeepKeys: + keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order + loader: + shuffle: True + drop_last: True + batch_size_per_card: 1 + num_workers: 8 + +Eval: + dataset: + name: PGDateSet + data_dir: ./train_data/ + label_file_list: + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - E2ELabelEncode: + label_list: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] + - E2EResizeForTest: + valid_set: totaltext + max_side_len: 768 + - NormalizeImage: + scale: 1./255. 
+ mean: [ 0.485, 0.456, 0.406 ] + std: [ 0.229, 0.224, 0.225 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 \ No newline at end of file diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 7cb50d7a..bcfbf489 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -34,6 +34,7 @@ import paddle.distributed as dist from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.lmdb_dataset import LMDBDataSet +from ppocr.data.pgnet_dataset import PGDateSet __all__ = ['build_dataloader', 'transform', 'create_operators'] @@ -54,7 +55,8 @@ signal.signal(signal.SIGTERM, term_mp) def build_dataloader(config, mode, device, logger, seed=None): config = copy.deepcopy(config) - support_dict = ['SimpleDataSet', 'LMDBDataSet'] + + support_dict = ['SimpleDataSet', 'LMDBDateSet', 'PGDateSet'] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 250ac75e..a808fd58 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -28,6 +28,7 @@ from .label_ops import * from .east_process import * from .sast_process import * +from .pg_process import * def transform(data, ops=None): diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 7a32d870..4cae2337 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -34,6 +34,25 @@ class ClsLabelEncode(object): return data +class E2ELabelEncode(object): + def __init__(self, label_list, **kwargs): + self.label_list = label_list + + def __call__(self, data): + text_label_index_list, temp_text = [], [] + texts = data['strs'] + for text in texts: + text = text.upper() + temp_text = [] + for c_ in text: + if c_ in self.label_list: + temp_text.append(self.label_list.index(c_)) + temp_text = temp_text + [36] * (50 - len(temp_text)) + text_label_index_list.append(temp_text) + data['strs'] = np.array(text_label_index_list) + return data + + class DetLabelEncode(object): def __init__(self, **kwargs): pass diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index eacfdf3b..d4cdad28 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -223,3 +223,74 @@ class DetResizeForTest(object): ratio_w = resize_w / float(w) return img, [ratio_h, ratio_w] + + +class E2EResizeForTest(object): + def __init__(self, **kwargs): + super(E2EResizeForTest, self).__init__() + self.max_side_len = kwargs['max_side_len'] + self.valid_set = kwargs['valid_set'] + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + if self.valid_set == 'totaltext': + im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext( + img, max_side_len=self.max_side_len) + else: + im_resized, (ratio_h, ratio_w) = self.resize_image( + img, max_side_len=self.max_side_len) + data['image'] = im_resized + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_for_totaltext(self, im, max_side_len=512): + """ + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + ratio = 1.25 + if h * ratio > max_side_len: + ratio = float(max_side_len) / resize_h + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 
128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return im, (ratio_h, ratio_w) + + def resize_image(self, im, max_side_len=512): + """ + resize image to a size multiple of max_stride which is required by the network + :param im: the resized image + :param max_side_len: limit of max image size to avoid out of memory in gpu + :return: the resized image and the resize ratio + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(max_side_len) / resize_h + else: + ratio = float(max_side_len) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return im, (ratio_h, ratio_w) diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py new file mode 100644 index 00000000..60abf194 --- /dev/null +++ b/ppocr/data/imaug/pg_process.py @@ -0,0 +1,921 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import cv2 +import numpy as np +import os + +__all__ = ['PGProcessTrain'] + + +class PGProcessTrain(object): + def __init__(self, + batch_size=14, + data_format='icdar', + tcl_len=64, + min_crop_size=24, + min_text_size=10, + max_text_size=512, + **kwargs): + self.batch_size = batch_size + self.data_format = data_format + self.tcl_len = tcl_len + self.min_crop_size = min_crop_size + self.min_text_size = min_text_size + self.max_text_size = max_text_size + self.Lexicon_Table = [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', + 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' + ] + self.img_id = 0 + + def quad_area(self, poly): + """ + compute area of a polygon + :param poly: + :return: + """ + edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), + (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), + (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), + (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])] + return np.sum(edge) / 2. + + def gen_quad_from_poly(self, poly): + """ + Generate min area quad from poly. 
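+        The quad is taken from cv2.minAreaRect; its corners are then
+        rolled so that the first corner is the one whose summed distance
+        to poly[0], poly[mid - 1], poly[mid] and poly[-1] is smallest,
+        which keeps the quad aligned with the reading order of the poly.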
+ """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + if True: + rect = cv2.minAreaRect(poly.astype( + np.int32)) # (center (x,y), (width, height), angle of rotation) + center_point = rect[0] + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + if dist < min_dist: + min_dist = dist + first_point_idx = i + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad + + def check_and_validate_polys(self, polys, tags, xxx_todo_changeme): + """ + check so that the text poly is in the same direction, + and also filter some invalid polygons + :param polys: + :param tags: + :return: + """ + (h, w) = xxx_todo_changeme + if polys.shape[0] == 0: + return polys, np.array([]), np.array([]) + polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) + polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) + + validated_polys = [] + validated_tags = [] + hv_tags = [] + for poly, tag in zip(polys, tags): + quad = self.gen_quad_from_poly(poly) + p_area = self.quad_area(quad) + if abs(p_area) < 1: + print('invalid poly') + continue + if p_area > 0: + if tag == False: + print('poly in wrong direction') + tag = True # reversed cases should be ignore + poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, + 1), :] + quad = quad[(0, 3, 2, 1), :] + + len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - + quad[2]) + len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - + quad[2]) + hv_tag = 1 + + if len_w * 2.0 < len_h: + hv_tag = 0 + + validated_polys.append(poly) + validated_tags.append(tag) + hv_tags.append(hv_tag) + return np.array(validated_polys), np.array(validated_tags), np.array( + hv_tags) + + def crop_area(self, + im, + polys, + tags, + hv_tags, + txts, + crop_background=False, + max_tries=25): + """ + make random crop from the input image + :param im: + :param polys: [b,4,2] + :param tags: + :param crop_background: + :param max_tries: 50 -> 25 + :return: + """ + h, w, _ = im.shape + pad_h = h // 10 + pad_w = w // 10 + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w:maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h:maxy + pad_h] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return im, polys, tags, hv_tags, txts + for i in range(max_tries): + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if xmax - xmin < self.min_crop_size or \ + ymax - ymin < self.min_crop_size: + continue + if polys.shape[0] != 0: + poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \ + & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax) + selected_polys = np.where( + 
np.sum(poly_axis_in_area, axis=1) == 4)[0] + else: + selected_polys = [] + if len(selected_polys) == 0: + # no text in this area + if crop_background: + txts_tmp = [] + for selected_poly in selected_polys: + txts_tmp.append(txts[selected_poly]) + txts = txts_tmp + # print(1111) + return im[ymin: ymax + 1, xmin: xmax + 1, :], \ + polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts + else: + continue + im = im[ymin:ymax + 1, xmin:xmax + 1, :] + polys = polys[selected_polys] + tags = tags[selected_polys] + hv_tags = hv_tags[selected_polys] + txts_tmp = [] + for selected_poly in selected_polys: + txts_tmp.append(txts[selected_poly]) + txts = txts_tmp + polys[:, :, 0] -= xmin + polys[:, :, 1] -= ymin + return im, polys, tags, hv_tags, txts + + return im, polys, tags, hv_tags, txts + + def fit_and_gather_tcl_points_v2(self, + min_area_quad, + poly, + max_h, + max_w, + fixed_point_num=64, + img_id=0, + reference_height=3): + """ + Find the center point of poly as key_points, then fit and gather. + """ + key_point_xys = [] + point_num = poly.shape[0] + for idx in range(point_num // 2): + center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0 + key_point_xys.append(center_point) + + tmp_image = np.zeros( + shape=( + max_h, + max_w, ), dtype='float32') + cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')], + False, 1.0) + ys, xs = np.where(tmp_image > 0) + xy_text = np.array(list(zip(xs, ys)), dtype='float32') + + # left_center_pt = np.array(key_point_xys[0]).reshape(1, 2) + # right_center_pt = np.array(key_point_xys[-1]).reshape(1, 2) + left_center_pt = ( + (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2) + right_center_pt = ( + (min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2) + proj_unit_vec = (right_center_pt - left_center_pt) / ( + np.linalg.norm(right_center_pt - left_center_pt) + 1e-6) + proj_unit_vec_tile = np.tile(proj_unit_vec, + (xy_text.shape[0], 1)) # (n, 2) + left_center_pt_tile = np.tile(left_center_pt, + (xy_text.shape[0], 1)) # (n, 2) + xy_text_to_left_center = xy_text - left_center_pt_tile + proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1) + xy_text = xy_text[np.argsort(proj_value)] + + # convert to np and keep the num of point not greater then fixed_point_num + pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx + point_num = len(pos_info) + if point_num > fixed_point_num: + keep_ids = [ + int((point_num * 1.0 / fixed_point_num) * x) + for x in range(fixed_point_num) + ] + pos_info = pos_info[keep_ids, :] + + keep = int(min(len(pos_info), fixed_point_num)) + if np.random.rand() < 0.2 and reference_height >= 3: + dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3 + random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape( + [keep, 1]) + pos_info += random_float + pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1) + pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1) + + # padding to fixed length + pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32) + pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id + pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32) + pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32) + pos_m[:keep] = 1.0 + return pos_l, pos_m + + def generate_direction_map(self, poly_quads, n_char, direction_map): + """ + """ + width_list = [] + height_list = [] + for quad in poly_quads: + quad_w = (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) / 2.0 + quad_h = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[2] - 
quad[1])) / 2.0 + width_list.append(quad_w) + height_list.append(quad_h) + norm_width = max(sum(width_list) / n_char, 1.0) + average_height = max(sum(height_list) / len(height_list), 1.0) + + for quad in poly_quads: + direct_vector_full = ( + (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0 + direct_vector = direct_vector_full / ( + np.linalg.norm(direct_vector_full) + 1e-6) * norm_width + direction_label = tuple( + map(float, + [direct_vector[0], direct_vector[1], 1.0 / average_height])) + cv2.fillPoly(direction_map, + quad.round().astype(np.int32)[np.newaxis, :, :], + direction_label) + return direction_map + + def calculate_average_height(self, poly_quads): + """ + """ + height_list = [] + for quad in poly_quads: + quad_h = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[2] - quad[1])) / 2.0 + height_list.append(quad_h) + average_height = max(sum(height_list) / len(height_list), 1.0) + return average_height + + def encode(self, text): + text_list = [] + for char in text: + if char not in self.dict: + continue + text_list.append([self.dict[char]]) + if len(text_list) == 0: + return None + return text_list + + def generate_tcl_ctc_label(self, + h, + w, + polys, + tags, + text_strs, + ds_ratio, + tcl_ratio=0.3, + shrink_ratio_of_width=0.15): + """ + Generate polygon. + """ + score_map_big = np.zeros( + ( + h, + w, ), dtype=np.float32) + h, w = int(h * ds_ratio), int(w * ds_ratio) + polys = polys * ds_ratio + + score_map = np.zeros( + ( + h, + w, ), dtype=np.float32) + score_label_map = np.zeros( + ( + h, + w, ), dtype=np.float32) + tbo_map = np.zeros((h, w, 5), dtype=np.float32) + training_mask = np.ones( + ( + h, + w, ), dtype=np.float32) + direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape( + [1, 1, 3]).astype(np.float32) + + label_idx = 0 + score_label_map_text_label_list = [] + pos_list, pos_mask, label_list = [], [], [] + for poly_idx, poly_tag in enumerate(zip(polys, tags)): + poly = poly_tag[0] + tag = poly_tag[1] + + # generate min_area_quad + min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) + min_area_quad_h = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2])) + min_area_quad_w = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3])) + + if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \ + or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio: + continue + + if tag: + # continue + cv2.fillPoly(training_mask, + poly.astype(np.int32)[np.newaxis, :, :], 0.15) + else: + text_label = text_strs[poly_idx] + text_label = self.prepare_text_label(text_label, + self.Lexicon_Table) + # text = text.decode('utf-8') + # text_label_index_list = self.encode(text) + + text_label_index_list = [[self.Lexicon_Table.index(c_)] + for c_ in text_label + if c_ in self.Lexicon_Table] + if len(text_label_index_list) < 1: + continue + + tcl_poly = self.poly2tcl(poly, tcl_ratio) + tcl_quads = self.poly2quads(tcl_poly) + poly_quads = self.poly2quads(poly) + # stcl map + stcl_quads, quad_index = self.shrink_poly_along_width( + tcl_quads, + shrink_ratio_of_width=shrink_ratio_of_width, + expand_height_ratio=1.0 / tcl_ratio) + # generate tcl map + cv2.fillPoly(score_map, + np.round(stcl_quads).astype(np.int32), 1.0) + cv2.fillPoly(score_map_big, + np.round(stcl_quads / ds_ratio).astype(np.int32), + 1.0) + + # generate tbo map + # tbo_tcl_poly = poly2tcl(poly, 0.5) + # tbo_tcl_quads = 
poly2quads(tbo_tcl_poly) + # for idx, quad in enumerate(tbo_tcl_quads): + for idx, quad in enumerate(stcl_quads): + quad_mask = np.zeros((h, w), dtype=np.float32) + quad_mask = cv2.fillPoly( + quad_mask, + np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0) + tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], + quad_mask, tbo_map) + + # score label map and score_label_map_text_label_list for refine + if label_idx == 0: + text_pos_list_ = [[len(self.Lexicon_Table)], ] + score_label_map_text_label_list.append(text_pos_list_) + + label_idx += 1 + # cv2.fillPoly(score_label_map, np.round(poly_quads[np.newaxis, :, :]).astype(np.int32), label_idx) + cv2.fillPoly(score_label_map, + np.round(poly_quads).astype(np.int32), label_idx) + score_label_map_text_label_list.append(text_label_index_list) + + # direction info, fix-me + n_char = len(text_label_index_list) + direction_map = self.generate_direction_map(poly_quads, n_char, + direction_map) + + # pos info + average_shrink_height = self.calculate_average_height( + stcl_quads) + pos_l, pos_m = self.fit_and_gather_tcl_points_v2( + min_area_quad, + poly, + max_h=h, + max_w=w, + fixed_point_num=64, + img_id=self.img_id, + reference_height=average_shrink_height) + + label_l = text_label_index_list + if len(text_label_index_list) < 2: + continue + + pos_list.append(pos_l) + pos_mask.append(pos_m) + label_list.append(label_l) + + # use big score_map for smooth tcl lines + score_map_big_resized = cv2.resize( + score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio) + score_map = np.array(score_map_big_resized > 1e-3, dtype='float32') + + return score_map, score_label_map, tbo_map, direction_map, training_mask, \ + pos_list, pos_mask, label_list, score_label_map_text_label_list + + def adjust_point(self, poly): + """ + adjust point order. + """ + point_num = poly.shape[0] + if point_num == 4: + len_1 = np.linalg.norm(poly[0] - poly[1]) + len_2 = np.linalg.norm(poly[1] - poly[2]) + len_3 = np.linalg.norm(poly[2] - poly[3]) + len_4 = np.linalg.norm(poly[3] - poly[0]) + + if (len_1 + len_3) * 1.5 < (len_2 + len_4): + poly = poly[[1, 2, 3, 0], :] + + elif point_num > 4: + vector_1 = poly[0] - poly[1] + vector_2 = poly[1] - poly[2] + cos_theta = np.dot(vector_1, vector_2) / ( + np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6) + theta = np.arccos(np.round(cos_theta, decimals=4)) + + if abs(theta) > (70 / 180 * math.pi): + index = list(range(1, point_num)) + [0] + poly = poly[np.array(index), :] + return poly + + def gen_min_area_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + if point_num == 4: + min_area_quad = poly + center_point = np.sum(poly, axis=0) / 4 + else: + rect = cv2.minAreaRect(poly.astype( + np.int32)) # (center (x,y), (width, height), angle of rotation) + center_point = rect[0] + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + if dist < min_dist: + min_dist = dist + first_point_idx = i + + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad, center_point + + def shrink_quad_along_width(self, + quad, + begin_width_ratio=0., + end_width_ratio=1.): + """ + Generate shrink_quad_along_width. 
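+        Interpolates along the upper edge (quad[0] -> quad[1]) and the
+        lower edge (quad[3] -> quad[2]) at begin_width_ratio and
+        end_width_ratio, and returns the clock-wise sub-quad between the
+        two cut positions.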
+ """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + def shrink_poly_along_width(self, + quads, + shrink_ratio_of_width, + expand_height_ratio=1.0): + """ + shrink poly with given length. + """ + upper_edge_list = [] + + def get_cut_info(edge_len_list, cut_len): + for idx, edge_len in enumerate(edge_len_list): + cut_len -= edge_len + if cut_len <= 0.000001: + ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx] + return idx, ratio + + for quad in quads: + upper_edge_len = np.linalg.norm(quad[0] - quad[1]) + upper_edge_list.append(upper_edge_len) + + # length of left edge and right edge. + left_length = np.linalg.norm(quads[0][0] - quads[0][ + 3]) * expand_height_ratio + right_length = np.linalg.norm(quads[-1][1] - quads[-1][ + 2]) * expand_height_ratio + + shrink_length = min(left_length, right_length, + sum(upper_edge_list)) * shrink_ratio_of_width + # shrinking length + upper_len_left = shrink_length + upper_len_right = sum(upper_edge_list) - shrink_length + + left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left) + left_quad = self.shrink_quad_along_width( + quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1) + right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right) + right_quad = self.shrink_quad_along_width( + quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio) + + out_quad_list = [] + if left_idx == right_idx: + out_quad_list.append( + [left_quad[0], right_quad[1], right_quad[2], left_quad[3]]) + else: + out_quad_list.append(left_quad) + for idx in range(left_idx + 1, right_idx): + out_quad_list.append(quads[idx]) + out_quad_list.append(right_quad) + + return np.array(out_quad_list), list(range(left_idx, right_idx + 1)) + + def prepare_text_label(self, label_str, Lexicon_Table): + """ + Prepare text lablel by given Lexicon_Table. + """ + if len(Lexicon_Table) == 36: + return label_str.upper() + else: + return label_str + + def vector_angle(self, A, B): + """ + Calculate the angle between vector AB and x-axis positive direction. + """ + AB = np.array([B[1] - A[1], B[0] - A[0]]) + return np.arctan2(*AB) + + def theta_line_cross_point(self, theta, point): + """ + Calculate the line through given point and angle in ax + by + c =0 form. + """ + x, y = point + cos = np.cos(theta) + sin = np.sin(theta) + return [sin, -cos, cos * y - sin * x] + + def line_cross_two_point(self, A, B): + """ + Calculate the line through given point A and B in ax + by + c =0 form. + """ + angle = self.vector_angle(A, B) + return self.theta_line_cross_point(angle, A) + + def average_angle(self, poly): + """ + Calculate the average angle between left and right edge in given poly. 
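+        Assumes a clock-wise quad [p0, p1, p2, p3]; returns the mean of
+        the angles of the vectors p3 -> p0 and p2 -> p1.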
+ """ + p0, p1, p2, p3 = poly + angle30 = self.vector_angle(p3, p0) + angle21 = self.vector_angle(p2, p1) + return (angle30 + angle21) / 2 + + def line_cross_point(self, line1, line2): + """ + line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2 + """ + a1, b1, c1 = line1 + a2, b2, c2 = line2 + d = a1 * b2 - a2 * b1 + + if d == 0: + # print("line1", line1) + # print("line2", line2) + print('Cross point does not exist') + return np.array([0, 0], dtype=np.float32) + else: + x = (b1 * c2 - b2 * c1) / d + y = (a2 * c1 - a1 * c2) / d + + return np.array([x, y], dtype=np.float32) + + def quad2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. (4, 2) + """ + ratio_pair = np.array( + [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair + p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair + return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]]) + + def poly2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. + """ + ratio_pair = np.array( + [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + tcl_poly = np.zeros_like(poly) + point_num = poly.shape[0] + + for idx in range(point_num // 2): + point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx] + ) * ratio_pair + tcl_poly[idx] = point_pair[0] + tcl_poly[point_num - 1 - idx] = point_pair[1] + return tcl_poly + + def gen_quad_tbo(self, quad, tcl_mask, tbo_map): + """ + Generate tbo_map for give quad. + """ + # upper and lower line function: ax + by + c = 0; + up_line = self.line_cross_two_point(quad[0], quad[1]) + lower_line = self.line_cross_two_point(quad[3], quad[2]) + + quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[1] - quad[2])) + quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) + + # average angle of left and right line. + angle = self.average_angle(quad) + + xy_in_poly = np.argwhere(tcl_mask == 1) + for y, x in xy_in_poly: + point = (x, y) + line = self.theta_line_cross_point(angle, point) + cross_point_upper = self.line_cross_point(up_line, line) + cross_point_lower = self.line_cross_point(lower_line, line) + ##FIX, offset reverse + upper_offset_x, upper_offset_y = cross_point_upper - point + lower_offset_x, lower_offset_y = cross_point_lower - point + tbo_map[y, x, 0] = upper_offset_y + tbo_map[y, x, 1] = upper_offset_x + tbo_map[y, x, 2] = lower_offset_y + tbo_map[y, x, 3] = lower_offset_x + tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2 + return tbo_map + + def poly2quads(self, poly): + """ + Split poly into quads. 
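+        A poly with 2N clock-wise points (upper edge first, then the
+        lower edge reversed) is paired point-by-point (poly[i] with
+        poly[2N - 1 - i]); each neighbouring pair of point pairs is then
+        reordered into a clock-wise (4, 2) quad, giving N - 1 quads.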
+ """ + quad_list = [] + point_num = poly.shape[0] + + # point pair + point_pair_list = [] + for idx in range(point_num // 2): + point_pair = [poly[idx], poly[point_num - 1 - idx]] + point_pair_list.append(point_pair) + + quad_num = point_num // 2 - 1 + for idx in range(quad_num): + # reshape and adjust to clock-wise + quad_list.append((np.array(point_pair_list)[[idx, idx + 1]] + ).reshape(4, 2)[[0, 2, 3, 1]]) + + return np.array(quad_list) + + def rotate_im_poly(self, im, text_polys): + """ + rotate image with 90 / 180 / 270 degre + """ + im_w, im_h = im.shape[1], im.shape[0] + dst_im = im.copy() + dst_polys = [] + rand_degree_ratio = np.random.rand() + rand_degree_cnt = 1 + if rand_degree_ratio > 0.5: + rand_degree_cnt = 3 + for i in range(rand_degree_cnt): + dst_im = np.rot90(dst_im) + rot_degree = -90 * rand_degree_cnt + rot_angle = rot_degree * math.pi / 180.0 + n_poly = text_polys.shape[0] + cx, cy = 0.5 * im_w, 0.5 * im_h + ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0] + for i in range(n_poly): + wordBB = text_polys[i] + poly = [] + for j in range(4): # 16->4 + sx, sy = wordBB[j][0], wordBB[j][1] + dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * ( + sy - cy) + ncx + dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * ( + sy - cy) + ncy + poly.append([dx, dy]) + dst_polys.append(poly) + return dst_im, np.array(dst_polys, dtype=np.float32) + + def __call__(self, data): + input_size = 512 + im = data['image'] + text_polys = data['polys'] + text_tags = data['tags'] + text_strs = data['strs'] + h, w, _ = im.shape + text_polys, text_tags, hv_tags = self.check_and_validate_polys( + text_polys, text_tags, (h, w)) + if text_polys.shape[0] <= 0: + return None + # set aspect ratio and keep area fix + asp_scales = np.arange(1.0, 1.55, 0.1) + asp_scale = np.random.choice(asp_scales) + if np.random.rand() < 0.5: + asp_scale = 1.0 / asp_scale + asp_scale = math.sqrt(asp_scale) + + asp_wx = asp_scale + asp_hy = 1.0 / asp_scale + im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy) + text_polys[:, :, 0] *= asp_wx + text_polys[:, :, 1] *= asp_hy + + h, w, _ = im.shape + if max(h, w) > 2048: + rd_scale = 2048.0 / max(h, w) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + h, w, _ = im.shape + if min(h, w) < 16: + return None + + # no background + im, text_polys, text_tags, hv_tags, text_strs = self.crop_area( + im, + text_polys, + text_tags, + hv_tags, + text_strs, + crop_background=False) + + if text_polys.shape[0] == 0: + return None + # # continue for all ignore case + if np.sum((text_tags * 1.0)) >= text_tags.size: + return None + new_h, new_w, _ = im.shape + if (new_h is None) or (new_w is None): + return None + # resize image + std_ratio = float(input_size) / max(new_w, new_h) + rand_scales = np.array( + [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]) + rz_scale = std_ratio * np.random.choice(rand_scales) + im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale) + text_polys[:, :, 0] *= rz_scale + text_polys[:, :, 1] *= rz_scale + + # add gaussian blur + if np.random.rand() < 0.1 * 0.5: + ks = np.random.permutation(5)[0] + 1 + ks = int(ks / 2) * 2 + 1 + im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0) + # add brighter + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 + np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + # add darker + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 - np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + + # Padding the im to [input_size, input_size] 
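+        # The canvas is pre-filled with the ImageNet channel means in
+        # BGR order (B = 0.406, G = 0.456, R = 0.485, scaled by 255),
+        # so the padded border becomes exactly zero after the
+        # mean-subtraction / std-division applied further below.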
+ new_h, new_w, _ = im.shape + if min(new_w, new_h) < input_size * 0.5: + return None + im_padded = np.ones((input_size, input_size, 3), dtype=np.float32) + im_padded[:, :, 2] = 0.485 * 255 + im_padded[:, :, 1] = 0.456 * 255 + im_padded[:, :, 0] = 0.406 * 255 + + # Random the start position + del_h = input_size - new_h + del_w = input_size - new_w + sh, sw = 0, 0 + if del_h > 1: + sh = int(np.random.rand() * del_h) + if del_w > 1: + sw = int(np.random.rand() * del_w) + + # Padding + im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy() + text_polys[:, :, 0] += sw + text_polys[:, :, 1] += sh + + score_map, score_label_map, border_map, direction_map, training_mask, \ + pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size, + input_size, + text_polys, + text_tags, + text_strs, 0.25) + if len(label_list) <= 0: # eliminate negative samples + return None + pos_list_temp = np.zeros([64, 3]) + pos_mask_temp = np.zeros([64, 1]) + label_list_temp = np.zeros([50, 1]) + 36 + + for i, label in enumerate(label_list): + n = len(label) + if n > 50: + label_list[i] = label[:50] + continue + while n < 50: + label.append([36]) + n += 1 + + for i in range(len(label_list)): + label_list[i] = np.array(label_list[i]) + + if len(pos_list) <= 0 or len(pos_list) > 30: + return None + for __ in range(30 - len(pos_list), 0, -1): + pos_list.append(pos_list_temp) + pos_mask.append(pos_mask_temp) + label_list.append(label_list_temp) + + if self.img_id == self.batch_size - 1: + self.img_id = 0 + else: + self.img_id += 1 + + im_padded[:, :, 2] -= 0.485 * 255 + im_padded[:, :, 1] -= 0.456 * 255 + im_padded[:, :, 0] -= 0.406 * 255 + im_padded[:, :, 2] /= (255.0 * 0.229) + im_padded[:, :, 1] /= (255.0 * 0.224) + im_padded[:, :, 0] /= (255.0 * 0.225) + im_padded = im_padded.transpose((2, 0, 1)) + images = im_padded[::-1, :, :] + tcl_maps = score_map[np.newaxis, :, :] + tcl_label_maps = score_label_map[np.newaxis, :, :] + border_maps = border_map.transpose((2, 0, 1)) + direction_maps = direction_map.transpose((2, 0, 1)) + training_masks = training_mask[np.newaxis, :, :] + pos_list = np.array(pos_list) + pos_mask = np.array(pos_mask) + label_list = np.array(label_list) + data['images'] = images + data['tcl_maps'] = tcl_maps + data['tcl_label_maps'] = tcl_label_maps + data['border_maps'] = border_maps + data['direction_maps'] = direction_maps + data['training_masks'] = training_masks + data['label_list'] = label_list + data['pos_list'] = pos_list + data['pos_mask'] = pos_mask + return data diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py new file mode 100644 index 00000000..82c580ce --- /dev/null +++ b/ppocr/data/pgnet_dataset.py @@ -0,0 +1,164 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
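+# A minimal, hypothetical sketch (not used by the class below) of the
+# 'icdar'-format annotation that extract_polys() parses: one instance
+# per line, the flattened polygon coordinates, a tab, then the
+# transcript, with '###' marking a don't-care region:
+#
+#     import numpy as np
+#     line = "10,10,90,12,88,40,8,38\tHELLO"
+#     poly_str, txt = line.strip().split('\t')
+#     poly = np.array(poly_str.split(','), dtype=np.float32).reshape(-1, 2)
+#     ignore = (txt == '###')   # poly.shape == (4, 2), ignore == False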
+import numpy as np +import os +from paddle.io import Dataset +from .imaug import transform, create_operators +import random + + +class PGDateSet(Dataset): + def __init__(self, config, mode, logger): + super(PGDateSet, self).__init__() + + self.logger = logger + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + loader_config = config[mode]['loader'] + + label_file_list = dataset_config.pop('label_file_list') + data_source_num = len(label_file_list) + ratio_list = dataset_config.get("ratio_list", [1.0]) + if isinstance(ratio_list, (float, int)): + ratio_list = [float(ratio_list)] * int(data_source_num) + self.data_format = dataset_config.get('data_format', 'icdar') + assert len( + ratio_list + ) == data_source_num, "The length of ratio_list should be the same as the file_list." + # self.data_dir = dataset_config['data_dir'] + self.do_shuffle = loader_config['shuffle'] + + logger.info("Initialize indexs of datasets:%s" % label_file_list) + self.data_lines = self.get_image_info_list(label_file_list, ratio_list, + self.data_format) + self.data_idx_order_list = list(range(len(self.data_lines))) + if mode.lower() == "train": + self.shuffle_data_random() + + self.ops = create_operators(dataset_config['transforms'], global_config) + + def shuffle_data_random(self): + if self.do_shuffle: + random.shuffle(self.data_lines) + return + + def extract_polys(self, poly_txt_path): + """ + Read text_polys, txt_tags, txts from give txt file. + """ + text_polys, txt_tags, txts = [], [], [] + with open(poly_txt_path) as f: + for line in f.readlines(): + poly_str, txt = line.strip().split('\t') + poly = map(float, poly_str.split(',')) + text_polys.append( + np.array( + list(poly), dtype=np.float32).reshape(-1, 2)) + txts.append(txt) + if txt == '###': + txt_tags.append(True) + else: + txt_tags.append(False) + + return np.array(list(map(np.array, text_polys))), \ + np.array(txt_tags, dtype=np.bool), txts + + def extract_info_textnet(self, im_fn, img_dir=''): + """ + Extract information from line in textnet format. 
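+        Each line is tab-separated: the image basename followed by nine
+        fields per box (eight corner coordinates x1,y1,...,x4,y4 and the
+        transcript), so the box count is (len(fields) - 1) // 9.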
+ """ + info_list = im_fn.split('\t') + img_path = '' + for ext in ['.jpg', '.png', '.jpeg', '.JPG']: + if os.path.exists(os.path.join(img_dir, info_list[0] + ext)): + img_path = os.path.join(img_dir, info_list[0] + ext) + break + + if img_path == '': + print('Image {0} NOT found in {1}, and it will be ignored.'.format( + info_list[0], img_dir)) + + nBox = (len(info_list) - 1) // 9 + wordBBs, txts, txt_tags = [], [], [] + for n in range(0, nBox): + wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9])) + txt = info_list[(n + 1) * 9] + wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]], + [wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]]) + txts.append(txt) + if txt == '###': + txt_tags.append(True) + else: + txt_tags.append(False) + return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts + + def get_image_info_list(self, file_list, ratio_list, data_format='textnet'): + if isinstance(file_list, str): + file_list = [file_list] + data_lines = [] + for idx, data_source in enumerate(file_list): + image_files = [] + if data_format == 'icdar': + image_files = [ + (data_source, x) + for x in os.listdir(os.path.join(data_source, 'rgb')) + if x.split('.')[-1] in ['jpg', 'png', 'jpeg', 'JPG'] + ] + elif data_format == 'textnet': + with open(data_source) as f: + image_files = [(data_source, x.strip()) + for x in f.readlines()] + else: + print("Unrecognized data format...") + exit(-1) + image_files = random.sample( + image_files, round(len(image_files) * ratio_list[idx])) + data_lines.extend(image_files) + return data_lines + + def __getitem__(self, idx): + file_idx = self.data_idx_order_list[idx] + data_path, data_line = self.data_lines[file_idx] + try: + if self.data_format == 'icdar': + im_path = os.path.join(data_path, 'rgb', data_line) + poly_path = os.path.join(data_path, 'poly', + data_line.split('.')[0] + '.txt') + text_polys, text_tags, text_strs = self.extract_polys(poly_path) + else: + image_dir = os.path.join(os.path.dirname(data_path), 'image') + im_path, text_polys, text_tags, text_strs = self.extract_info_textnet( + data_line, image_dir) + + data = { + 'img_path': im_path, + 'polys': text_polys, + 'tags': text_tags, + 'strs': text_strs + } + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + outs = transform(data, self.ops) + except Exception as e: + self.logger.error( + "When parsing line {}, error happened with msg: {}".format( + self.data_idx_order_list[idx], e)) + outs = None + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return len(self.data_idx_order_list) diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 3881abf7..223ae6b1 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -29,10 +29,11 @@ def build_loss(config): # cls loss from .cls_loss import ClsLoss + # e2e loss + from .e2e_pg_loss import PGLoss support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss' - ] + 'SRNLoss', 'PGLoss'] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/e2e_pg_loss.py b/ppocr/losses/e2e_pg_loss.py new file mode 100644 index 00000000..05480a9e --- /dev/null +++ b/ppocr/losses/e2e_pg_loss.py @@ -0,0 +1,219 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn +import paddle +import numpy as np +import copy + +from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss + + +class PGLoss(nn.Layer): + """ + Differentiable Binarization (DB) Loss Function + args: + param (dict): the super paramter for DB Loss + """ + + def __init__(self, alpha=5, beta=10, eps=1e-6, **kwargs): + super(PGLoss, self).__init__() + self.alpha = alpha + self.beta = beta + self.dice_loss = DiceLoss(eps=eps) + + def org_tcl_rois(self, batch_size, pos_lists, pos_masks, label_lists): + """ + """ + pos_lists_, pos_masks_, label_lists_ = [], [], [] + img_bs = batch_size + tcl_bs = 64 + ngpu = int(batch_size / img_bs) + img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy() + pos_lists_split, pos_masks_split, label_lists_split = [], [], [] + for i in range(ngpu): + pos_lists_split.append([]) + pos_masks_split.append([]) + label_lists_split.append([]) + + for i in range(img_ids.shape[0]): + img_id = img_ids[i] + gpu_id = int(img_id / img_bs) + img_id = img_id % img_bs + pos_list = pos_lists[i].copy() + pos_list[:, 0] = img_id + pos_lists_split[gpu_id].append(pos_list) + pos_masks_split[gpu_id].append(pos_masks[i].copy()) + label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i])) + # repeat or delete + for i in range(ngpu): + vp_len = len(pos_lists_split[i]) + if vp_len <= tcl_bs: + for j in range(0, tcl_bs - vp_len): + pos_list = pos_lists_split[i][j].copy() + pos_lists_split[i].append(pos_list) + pos_mask = pos_masks_split[i][j].copy() + pos_masks_split[i].append(pos_mask) + label_list = copy.deepcopy(label_lists_split[i][j]) + label_lists_split[i].append(label_list) + else: + for j in range(0, vp_len - tcl_bs): + c_len = len(pos_lists_split[i]) + pop_id = np.random.permutation(c_len)[0] + pos_lists_split[i].pop(pop_id) + pos_masks_split[i].pop(pop_id) + label_lists_split[i].pop(pop_id) + # merge + for i in range(ngpu): + pos_lists_.extend(pos_lists_split[i]) + pos_masks_.extend(pos_masks_split[i]) + label_lists_.extend(label_lists_split[i]) + return pos_lists_, pos_masks_, label_lists_ + + def pre_process(self, label_list, pos_list, pos_mask): + label_list = label_list.numpy() + b, h, w, c = label_list.shape + pos_list = pos_list.numpy() + pos_mask = pos_mask.numpy() + pos_list_t = [] + pos_mask_t = [] + label_list_t = [] + for i in range(b): + for j in range(30): + if pos_mask[i, j].any(): + pos_list_t.append(pos_list[i][j]) + pos_mask_t.append(pos_mask[i][j]) + label_list_t.append(label_list[i][j]) + pos_list, pos_mask, label_list = self.org_tcl_rois( + b, pos_list_t, pos_mask_t, label_list_t) + label = [] + tt = [l.tolist() for l in label_list] + for i in range(64): + k = 0 + for j in range(50): + if tt[i][j][0] != 36: + k += 1 + else: + break + label.append(k) + label = paddle.to_tensor(label) + label = paddle.cast(label, dtype='int64') + pos_list = paddle.to_tensor(pos_list) + pos_mask = paddle.to_tensor(pos_mask) + label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2) + label_list = 
paddle.cast(label_list, dtype='int32') + return pos_list, pos_mask, label_list, label + + def border_loss(self, f_border, l_border, l_score, l_mask): + l_border_split, l_border_norm = paddle.tensor.split( + l_border, num_or_sections=[4, 1], axis=1) + f_border_split = f_border + b, c, h, w = l_border_norm.shape + l_border_norm_split = paddle.expand( + x=l_border_norm, shape=[b, 4 * c, h, w]) + b, c, h, w = l_score.shape + l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w]) + b, c, h, w = l_mask.shape + l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w]) + border_diff = l_border_split - f_border_split + abs_border_diff = paddle.abs(border_diff) + border_sign = abs_border_diff < 1.0 + border_sign = paddle.cast(border_sign, dtype='float32') + border_sign.stop_gradient = True + border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \ + (abs_border_diff - 0.5) * (1.0 - border_sign) + border_out_loss = l_border_norm_split * border_in_loss + border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \ + (paddle.sum(l_border_score * l_border_mask) + 1e-5) + return border_loss + + def direction_loss(self, f_direction, l_direction, l_score, l_mask): + l_direction_split, l_direction_norm = paddle.tensor.split( + l_direction, num_or_sections=[2, 1], axis=1) + f_direction_split = f_direction + b, c, h, w = l_direction_norm.shape + l_direction_norm_split = paddle.expand( + x=l_direction_norm, shape=[b, 2 * c, h, w]) + b, c, h, w = l_score.shape + l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w]) + b, c, h, w = l_mask.shape + l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w]) + direction_diff = l_direction_split - f_direction_split + abs_direction_diff = paddle.abs(direction_diff) + direction_sign = abs_direction_diff < 1.0 + direction_sign = paddle.cast(direction_sign, dtype='float32') + direction_sign.stop_gradient = True + direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \ + (abs_direction_diff - 0.5) * (1.0 - direction_sign) + direction_out_loss = l_direction_norm_split * direction_in_loss + direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \ + (paddle.sum(l_direction_score * l_direction_mask) + 1e-5) + return direction_loss + + def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t): + f_char = paddle.transpose(f_char, [0, 2, 3, 1]) + tcl_pos = paddle.reshape(tcl_pos, [-1, 3]) + tcl_pos = paddle.cast(tcl_pos, dtype=int) + f_tcl_char = paddle.gather_nd(f_char, tcl_pos) + f_tcl_char = paddle.reshape(f_tcl_char, + [-1, 64, 37]) # len(Lexicon_Table)+1 + f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2) + f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0 + b, c, l = tcl_mask.shape + tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l]) + tcl_mask_fg.stop_gradient = True + f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * ( + -20.0) + f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2) + f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2)) + N, B, _ = f_tcl_char_ld.shape + input_lengths = paddle.to_tensor([N] * B, dtype='int64') + cost = paddle.nn.functional.ctc_loss( + log_probs=f_tcl_char_ld, + labels=tcl_label, + input_lengths=input_lengths, + label_lengths=label_t, + blank=36, + reduction='none') + cost = cost.mean() + return cost + + def forward(self, predicts, labels): + images, tcl_maps, tcl_label_maps, border_maps \ + , 
direction_maps, training_masks, label_list, pos_list, pos_mask = labels + # for all the batch_size + pos_list, pos_mask, label_list, label_t = self.pre_process( + label_list, pos_list, pos_mask) + + f_score, f_boder, f_direction, f_char = predicts + score_loss = self.dice_loss(f_score, tcl_maps, training_masks) + border_loss = self.border_loss(f_boder, border_maps, tcl_maps, + training_masks) + direction_loss = self.direction_loss(f_direction, direction_maps, + tcl_maps, training_masks) + ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t) + loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss + + losses = { + 'loss': loss_all, + "score_loss": score_loss, + "border_loss": border_loss, + "direction_loss": direction_loss, + "ctc_loss": ctc_loss + } + return losses diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index a0e7d912..f913010d 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -26,8 +26,9 @@ def build_metric(config): from .det_metric import DetMetric from .rec_metric import RecMetric from .cls_metric import ClsMetric + from .e2e_metric import E2EMetric - support_dict = ['DetMetric', 'RecMetric', 'ClsMetric'] + support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric'] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py new file mode 100644 index 00000000..6901187a --- /dev/null +++ b/ppocr/metrics/e2e_metric.py @@ -0,0 +1,87 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__all__ = ['E2EMetric'] + +from ppocr.utils.e2e_metric.Deteval import * + + +class E2EMetric(object): + def __init__(self, main_indicator='f_score_e2e', **kwargs): + self.label_list = [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', + 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' + ] + self.main_indicator = main_indicator + self.reset() + + def __call__(self, preds, batch, **kwargs): + ''' + batch: a list produced by dataloaders. + image: np.ndarray of shape (N, C, H, W). + ratio_list: np.ndarray of shape(N,2) + polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not. + preds: a list of dict produced by post process + points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. 
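+            strs: the predicted transcript for each polygon in points.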
+ ''' + + gt_polyons_batch = batch[2] + temp_gt_strs_batch = batch[3] + ignore_tags_batch = batch[4] + gt_strs_batch = [] + temp_gt_strs_batch = temp_gt_strs_batch[0].tolist() + for temp_list in temp_gt_strs_batch: + t = "" + for index in temp_list: + if index < 36: + t += self.label_list[index] + gt_strs_batch.append(t) + + for pred, gt_polyons, gt_strs, ignore_tags in zip( + preds, gt_polyons_batch, gt_strs_batch, ignore_tags_batch): + # prepare gt + gt_info_list = [{ + 'points': gt_polyon, + 'text': gt_str, + 'ignore': ignore_tag + } for gt_polyon, gt_str, ignore_tag in + zip(gt_polyons, gt_strs, ignore_tags)] + # prepare det + e2e_info_list = [{ + 'points': det_polyon, + 'text': pred_str + } for det_polyon, pred_str in zip(pred['points'], preds['strs'])] + result = get_socre(gt_info_list, e2e_info_list) + self.results.append(result) + + def get_metric(self): + """ + return metrics { + 'precision': 0, + 'recall': 0, + 'hmean': 0 + } + """ + metircs = combine_results(self.results) + self.reset() + return metircs + + def reset(self): + self.results = [] # clear results diff --git a/ppocr/metrics/eval_det_iou.py b/ppocr/metrics/eval_det_iou.py index a2a3f418..0e32b2d1 100644 --- a/ppocr/metrics/eval_det_iou.py +++ b/ppocr/metrics/eval_det_iou.py @@ -150,7 +150,7 @@ class DetectionIoUEvaluator(object): pairs.append({'gt': gtNum, 'det': detNum}) detMatchedNums.append(detNum) evaluationLog += "Match GT #" + \ - str(gtNum) + " with Det #" + str(detNum) + "\n" + str(gtNum) + " with Det #" + str(detNum) + "\n" numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) numDetCare = (len(detPols) - len(detDontCarePolsNum)) @@ -162,7 +162,7 @@ class DetectionIoUEvaluator(object): precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare hmean = 0 if (precision + recall) == 0 else 2.0 * \ - precision * recall / (precision + recall) + precision * recall / (precision + recall) matchedSum += detMatched numGlobalCareGt += numGtCare @@ -200,7 +200,8 @@ class DetectionIoUEvaluator(object): methodPrecision = 0 if numGlobalCareDet == 0 else float( matchedSum) / numGlobalCareDet methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \ - methodRecall * methodPrecision / (methodRecall + methodPrecision) + methodRecall * methodPrecision / ( + methodRecall + methodPrecision) # print(methodRecall, methodPrecision, methodHmean) # sys.exit(-1) methodMetrics = { diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 03c15508..fe2c9bc3 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -26,6 +26,9 @@ def build_backbone(config, model_type): from .rec_resnet_vd import ResNet from .rec_resnet_fpn import ResNetFPN support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN'] + elif model_type == 'e2e': + from .e2e_resnet_vd_pg import ResNet + support_dict = ['ResNet'] else: raise NotImplementedError diff --git a/ppocr/modeling/backbones/e2e_resnet_vd_pg.py b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py new file mode 100644 index 00000000..8e3697ec --- /dev/null +++ b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py @@ -0,0 +1,267 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ["ResNet"] + + +class ConvBNLayer(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None, ): + super(ConvBNLayer, self).__init__() + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = nn.BatchNorm( + out_channels, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def forward(self, inputs): + # if self.is_vd_mode: + # inputs = self._pool2d_avg(inputs) + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class BottleneckBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1, + act=None, + name=name + "_branch2c") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=stride, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv2) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BasicBlock, self).__init__() + self.stride = stride + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + act=None, + name=name + "_branch2b") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + 
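Both residual blocks defined in this backbone follow the same pattern: an optional 1x1 ConvBNLayer projection on the shortcut when shapes change, then paddle.add followed by ReLU. A minimal, self-contained sketch of the identity-shortcut case (shapes unchanged, so no projection is needed):

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class TinyResidualBlock(nn.Layer):
    # Identity shortcut: input and output shapes match, so no 1x1 projection.
    def __init__(self, channels):
        super(TinyResidualBlock, self).__init__()
        self.conv = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm(channels)

    def forward(self, x):
        return F.relu(paddle.add(x=x, y=self.bn(self.conv(x))))

x = paddle.rand([1, 16, 32, 32])
print(TinyResidualBlock(16)(x).shape)  # [1, 16, 32, 32]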
self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv1) + y = F.relu(y) + return y + + +class ResNet(nn.Layer): + def __init__(self, in_channels=3, layers=50, **kwargs): + super(ResNet, self).__init__() + + self.layers = layers + supported_layers = [18, 34, 50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + # depth = [3, 4, 6, 3] + depth = [3, 4, 6, 3, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + num_channels = [64, 256, 512, 1024, + 2048] if layers >= 50 else [64, 64, 128, 256] + num_filters = [64, 128, 256, 512, 512] + + self.conv1_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=7, + stride=2, + act='relu', + name="conv1_1") + self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + + self.stages = [] + self.out_channels = [3, 64] + # num_filters = [64, 128, 256, 512, 512] + if layers >= 50: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + bottleneck_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BottleneckBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(bottleneck_block) + self.out_channels.append(num_filters[block] * 4) + self.stages.append(nn.Sequential(*block_list)) + else: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + basic_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BasicBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block], + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(basic_block) + self.out_channels.append(num_filters[block]) + self.stages.append(nn.Sequential(*block_list)) + + def forward(self, inputs): + out = [inputs] + y = self.conv1_1(inputs) + out.append(y) + y = self.pool2d_max(y) + for block in self.stages: + y = block(y) + out.append(y) + return out diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index efe05718..4852c7f2 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -20,6 +20,7 @@ def build_head(config): from .det_db_head import DBHead from .det_east_head import EASTHead from .det_sast_head import SASTHead + from .e2e_pg_head import PGHead # rec head from .rec_ctc_head import CTCHead @@ -30,8 +31,8 @@ def build_head(config): from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead' - ] + 'SRNHead', 'PGHead'] + module_name = config.pop('name') assert module_name in support_dict, 
Exception('head only support {}'.format( diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py new file mode 100644 index 00000000..41ead8e8 --- /dev/null +++ b/ppocr/modeling/heads/e2e_pg_head.py @@ -0,0 +1,249 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + if_act=True, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance", + use_global_stats=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class PGHead(nn.Layer): + """ + """ + + def __init__(self, in_channels, model_name, **kwargs): + super(PGHead, self).__init__() + self.model_name = model_name + self.conv_f_score1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_score{}".format(1)) + self.conv_f_score2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_score{}".format(2)) + self.conv_f_score3 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_score{}".format(3)) + + self.conv1 = nn.Conv2D( + in_channels=128, + out_channels=1, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_score{}".format(4)), + bias_attr=False) + + self.conv_f_boder1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_boder{}".format(1)) + self.conv_f_boder2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_boder{}".format(2)) + self.conv_f_boder3 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_boder{}".format(3)) + self.conv2 = nn.Conv2D( + in_channels=128, + out_channels=4, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_boder{}".format(4)), + bias_attr=False) + self.conv_f_char1 = ConvBNLayer( + 
in_channels=in_channels, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_char{}".format(1)) + self.conv_f_char2 = ConvBNLayer( + in_channels=128, + out_channels=128, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_char{}".format(2)) + self.conv_f_char3 = ConvBNLayer( + in_channels=128, + out_channels=256, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_char{}".format(3)) + self.conv_f_char4 = ConvBNLayer( + in_channels=256, + out_channels=256, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_char{}".format(4)) + self.conv_f_char5 = ConvBNLayer( + in_channels=256, + out_channels=256, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_char{}".format(5)) + self.conv3 = nn.Conv2D( + in_channels=256, + out_channels=6625, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_char{}".format(6)), + bias_attr=False) + + self.conv_f_direc1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_direc{}".format(1)) + self.conv_f_direc2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_direc{}".format(2)) + self.conv_f_direc3 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_direc{}".format(3)) + self.conv4 = nn.Conv2D( + in_channels=128, + out_channels=2, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_direc{}".format(4)), + bias_attr=False) + + def forward(self, x): + f_score = self.conv_f_score1(x) + f_score = self.conv_f_score2(f_score) + f_score = self.conv_f_score3(f_score) + f_score = self.conv1(f_score) + f_score = F.sigmoid(f_score) + + # f_boder + f_boder = self.conv_f_boder1(x) + f_boder = self.conv_f_boder2(f_boder) + f_boder = self.conv_f_boder3(f_boder) + f_boder = self.conv2(f_boder) + + f_char = self.conv_f_char1(x) + f_char = self.conv_f_char2(f_char) + f_char = self.conv_f_char3(f_char) + f_char = self.conv_f_char4(f_char) + f_char = self.conv_f_char5(f_char) + f_char = self.conv3(f_char) + + f_direction = self.conv_f_direc1(x) + f_direction = self.conv_f_direc2(f_direction) + f_direction = self.conv_f_direc3(f_direction) + f_direction = self.conv4(f_direction) + + return f_score, f_boder, f_direction, f_char diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index 405e062b..37a5cf78 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -14,12 +14,14 @@ __all__ = ['build_neck'] + def build_neck(config): from .db_fpn import DBFPN from .east_fpn import EASTFPN from .sast_fpn import SASTFPN from .rnn import SequenceEncoder - support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder'] + from .pg_fpn import PGFPN + support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN'] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/ppocr/modeling/necks/pg_fpn.py b/ppocr/modeling/necks/pg_fpn.py new file mode 100644 index 00000000..9bd560c9 --- /dev/null +++ b/ppocr/modeling/necks/pg_fpn.py @@ -0,0 +1,277 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
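The head defined above runs four parallel convolutional branches over the shared FPN feature and, as its forward pass shows, returns a sigmoid text-center score map (1 channel), a border-offset map (4 channels), a direction map (2 channels), and character logits. A hedged usage sketch, assuming the PGHead class from this patch is importable and fed a 128-channel feature:

import paddle

head = PGHead(in_channels=128, model_name='large')
feat = paddle.rand([1, 128, 32, 32])
f_score, f_boder, f_direction, f_char = head(feat)
print(f_score.shape)      # [1, 1, 32, 32], sigmoid text-center scores
print(f_boder.shape)      # [1, 4, 32, 32], border offsets (two point pairs per pixel)
print(f_direction.shape)  # [1, 2, 32, 32], per-pixel reading direction
print(f_char.shape)       # [1, 6625, 32, 32], character classification logits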
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = nn.BatchNorm( + out_channels, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + use_global_stats=False) + + def forward(self, inputs): + # if self.is_vd_mode: + # inputs = self._pool2d_avg(inputs) + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class DeConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size=4, + stride=2, + padding=1, + groups=1, + if_act=True, + act=None, + name=None): + super(DeConvBNLayer, self).__init__() + + self.if_act = if_act + self.act = act + self.deconv = nn.Conv2DTranspose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance", + use_global_stats=False) + + def forward(self, x): + x = self.deconv(x) + x = self.bn(x) + return x + + +class FPN_Up_Fusion(nn.Layer): + def __init__(self, in_channels): + super(FPN_Up_Fusion, self).__init__() + in_channels = in_channels[::-1] + out_channels = [256, 256, 192, 192, 128] + + self.h0_conv = ConvBNLayer( + in_channels[0], out_channels[0], 1, 1, act=None, name='conv_h0') + self.h1_conv = ConvBNLayer( + in_channels[1], out_channels[1], 1, 1, act=None, name='conv_h1') + self.h2_conv = ConvBNLayer( + in_channels[2], out_channels[2], 1, 1, act=None, name='conv_h2') + self.h3_conv = ConvBNLayer( + in_channels[3], out_channels[3], 1, 1, act=None, name='conv_h3') + self.h4_conv = ConvBNLayer( + in_channels[4], out_channels[4], 1, 1, act=None, name='conv_h4') + + self.dconv0 = DeConvBNLayer( + in_channels=out_channels[0], + out_channels=out_channels[1], + name="dconv_{}".format(0)) + self.dconv1 = DeConvBNLayer( + in_channels=out_channels[1], + out_channels=out_channels[2], + act=None, + name="dconv_{}".format(1)) + self.dconv2 = DeConvBNLayer( + 
in_channels=out_channels[2], + out_channels=out_channels[3], + act=None, + name="dconv_{}".format(2)) + self.dconv3 = DeConvBNLayer( + in_channels=out_channels[3], + out_channels=out_channels[4], + act=None, + name="dconv_{}".format(3)) + self.conv_g1 = ConvBNLayer( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(1)) + self.conv_g2 = ConvBNLayer( + in_channels=out_channels[2], + out_channels=out_channels[2], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(2)) + self.conv_g3 = ConvBNLayer( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(3)) + self.conv_g4 = ConvBNLayer( + in_channels=out_channels[4], + out_channels=out_channels[4], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(4)) + self.convf = ConvBNLayer( + in_channels=out_channels[4], + out_channels=out_channels[4], + kernel_size=1, + stride=1, + act=None, + name="conv_f{}".format(4)) + + def _add_relu(self, x1, x2): + x = paddle.add(x=x1, y=x2) + x = F.relu(x) + return x + + def forward(self, x): + f = x[2:][::-1] + h0 = self.h0_conv(f[0]) + h1 = self.h1_conv(f[1]) + h2 = self.h2_conv(f[2]) + h3 = self.h3_conv(f[3]) + h4 = self.h4_conv(f[4]) + + g0 = self.dconv0(h0) + + g1 = self.dconv2(self.conv_g2(self._add_relu(g0, h1))) + g2 = self.dconv2(self.conv_g2(self._add_relu(g1, h2))) + g3 = self.dconv3(self.conv_g2(self._add_relu(g2, h3))) + g4 = self.dconv4(self.conv_g2(self._add_relu(g3, h4))) + return g4 + + +class FPN_Down_Fusion(nn.Layer): + def __init__(self, in_channels): + super(FPN_Down_Fusion, self).__init__() + out_channels = [32, 64, 128] + + self.h0_conv = ConvBNLayer( + in_channels[0], out_channels[0], 3, 1, act=None, name='FPN_d1') + self.h1_conv = ConvBNLayer( + in_channels[1], out_channels[1], 3, 1, act=None, name='FPN_d2') + self.h2_conv = ConvBNLayer( + in_channels[2], out_channels[2], 3, 1, act=None, name='FPN_d3') + + self.g0_conv = ConvBNLayer( + out_channels[0], out_channels[1], 3, 2, act=None, name='FPN_d4') + + self.g1_conv = nn.Sequential( + ConvBNLayer( + out_channels[1], + out_channels[1], + 3, + 1, + act='relu', + name='FPN_d5'), + ConvBNLayer( + out_channels[1], out_channels[2], 3, 2, act=None, + name='FPN_d6')) + + self.g2_conv = nn.Sequential( + ConvBNLayer( + out_channels[2], + out_channels[2], + 3, + 1, + act='relu', + name='FPN_d7'), + ConvBNLayer( + out_channels[2], out_channels[2], 1, 1, act=None, + name='FPN_d8')) + + def forward(self, x): + f = x[:3] + h0 = self.h0_conv(f[0]) + h1 = self.h1_conv(f[1]) + h2 = self.h2_conv(f[2]) + g0 = self.g0_conv(h0) + g1 = paddle.add(x=g0, y=h1) + g1 = F.relu(g1) + g1 = self.g1_conv(g1) + g2 = paddle.add(x=g1, y=h2) + g2 = F.relu(g2) + g2 = self.g2_conv(g2) + return g2 + + +class PGFPN(nn.Layer): + def __init__(self, in_channels, with_cab=False, **kwargs): + super(PGFPN, self).__init__() + self.in_channels = in_channels + self.with_cab = with_cab + self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels) + self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels) + self.out_channels = 128 + + def forward(self, x): + # down fpn + f_down = self.FPN_Down_Fusion(x) + + # up fpn + f_up = self.FPN_Up_Fusion(x) + + # fusion + f_common = paddle.add(x=f_down, y=f_up) + f_common = F.relu(f_common) + + return f_common diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 0156e438..042654a1 100644 --- a/ppocr/postprocess/__init__.py +++ 
b/ppocr/postprocess/__init__.py @@ -28,10 +28,11 @@ def build_post_process(config, global_config=None): from .sast_postprocess import SASTPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode from .cls_postprocess import ClsPostProcess + from .pg_postprocess import PGPostProcess support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', - 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode' + 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py new file mode 100644 index 00000000..90031a83 --- /dev/null +++ b/ppocr/postprocess/pg_postprocess.py @@ -0,0 +1,194 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) + +import numpy as np +from .locality_aware_nms import nms_locality +from ppocr.utils.e2e_utils.extract_textpoint import * +from ppocr.utils.e2e_utils.ski_thin import * +from ppocr.utils.e2e_utils.visual import * +import paddle +import cv2 +import time + + +class PGPostProcess(object): + """ + The post process for SAST. 
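With PGPostProcess registered in the list above, the generic builder pops the 'name' key from the config and forwards the remaining keys as keyword arguments. A minimal sketch of that dispatch, hard-wired to PGPostProcess (assumed importable) and using thresholds that match its defaults shown below:

import copy

def build_pg_postprocess(config):
    # Hedged sketch of the name-dispatch pattern used by build_post_process.
    config = copy.deepcopy(config)
    name = config.pop('name')
    assert name == 'PGPostProcess', 'this sketch only dispatches PGPostProcess'
    return PGPostProcess(**config)  # remaining keys become keyword arguments

post_op = build_pg_postprocess(
    {'name': 'PGPostProcess', 'score_thresh': 0.5, 'nms_thresh': 0.2})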
+ """ + + def __init__(self, + score_thresh=0.5, + nms_thresh=0.2, + sample_pts_num=2, + shrink_ratio_of_width=0.3, + expand_scale=1.0, + tcl_map_thresh=0.5, + **kwargs): + self.result_path = "" + self.valid_set = 'totaltext' + self.Lexicon_Table = [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', + 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' + ] + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.sample_pts_num = sample_pts_num + self.shrink_ratio_of_width = shrink_ratio_of_width + self.expand_scale = expand_scale + self.tcl_map_thresh = tcl_map_thresh + + # c++ la-nms is faster, but only support python 3.5 + self.is_python35 = False + if sys.version_info.major == 3 and sys.version_info.minor == 5: + self.is_python35 = True + + def __call__(self, outs_dict, shape_list): + p_score, p_border, p_direction, p_char = outs_dict[:4] + p_score = p_score[0].numpy() + p_border = p_border[0].numpy() + p_direction = p_direction[0].numpy() + p_char = p_char[0].numpy() + src_h, src_w, ratio_h, ratio_w = shape_list[0] + if self.valid_set != 'totaltext': + is_curved = False + else: + is_curved = True + instance_yxs_list = generate_pivot_list( + p_score, + p_char, + p_direction, + score_thresh=self.score_thresh, + is_backbone=True, + is_curved=is_curved) + p_char = np.expand_dims(p_char, axis=0) + p_char = paddle.to_tensor(p_char) + char_seq_idx_set = [] + for i in range(len(instance_yxs_list)): + gather_info_lod = paddle.to_tensor(instance_yxs_list[i]) + f_char_map = paddle.transpose(p_char, [0, 2, 3, 1]) + featyre_seq = paddle.gather_nd(f_char_map, gather_info_lod) + featyre_seq = np.expand_dims(featyre_seq.numpy(), axis=0) + t = len(featyre_seq[0]) + featyre_seq = paddle.to_tensor(featyre_seq) + l = np.array([[t]]).astype(np.int64) + length = paddle.to_tensor(l) + seq_pred = paddle.fluid.layers.ctc_greedy_decoder( + input=featyre_seq, blank=36, input_length=length) + seq_pred1 = seq_pred[0].numpy().tolist()[0] + seq_len = seq_pred[1].numpy()[0][0] + temp_t = [] + for x in seq_pred1[:seq_len]: + temp_t.append(x) + char_seq_idx_set.append(temp_t) + seq_strs = [] + for char_idx_set in char_seq_idx_set: + pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set]) + seq_strs.append(pr_str) + poly_list = [] + keep_str_list = [] + all_point_list = [] + all_point_pair_list = [] + for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): + if len(yx_center_line) == 1: + print('the length of tcl point is less than 2, repeat') + yx_center_line.append(yx_center_line[-1]) + + # expand corresponding offset for total-text. 
+ offset_expand = 1.0 + if self.valid_set == 'totaltext': + offset_expand = 1.2 + + point_pair_list = [] + for batch_id, y, x in yx_center_line: + offset = p_border[:, y, x].reshape(2, 2) + if offset_expand != 1.0: + offset_length = np.linalg.norm( + offset, axis=1, keepdims=True) + expand_length = np.clip( + offset_length * (offset_expand - 1), + a_min=0.5, + a_max=3.0) + offset_detal = offset / offset_length * expand_length + offset = offset + offset_detal + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) + point_pair_list.append(point_pair) + + # for visualization + all_point_list.append([ + int(round(x * 4.0 / ratio_w)), + int(round(y * 4.0 / ratio_h)) + ]) + all_point_pair_list.append(point_pair.round().astype(np.int32) + .tolist()) + + # ndarry: (x, 2) + detected_poly, pair_length_info = point_pair2poly(point_pair_list) + print('expand along width. {}'.format(detected_poly.shape)) + detected_poly = expand_poly_along_width( + detected_poly, shrink_ratio_of_width=0.2) + detected_poly[:, 0] = np.clip( + detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip( + detected_poly[:, 1], a_min=0, a_max=src_h) + + if len(keep_str) < 2: + print('--> too short, {}'.format(keep_str)) + continue + + keep_str_list.append(keep_str) + if self.valid_set == 'partvgg': + middle_point = len(detected_poly) // 2 + detected_poly = detected_poly[ + [0, middle_point - 1, middle_point, -1], :] + poly_list.append(detected_poly) + elif self.valid_set == 'totaltext': + poly_list.append(detected_poly) + else: + print('--> Not supported format.') + exit(-1) + data = { + 'points': poly_list, + 'strs': keep_str_list, + } + # visualization + # if self.save_visualization: + # visualize_e2e_result(im_fn, poly_list, keep_str_list, src_im) + # visualize_point_result(im_fn, all_point_list, all_point_pair_list, src_im) + + # save detected boxes + # txt_dir = (result_path[:-1] if result_path.endswith('/') else result_path) + '_txt_anno' + # if not os.path.exists(txt_dir): + # os.makedirs(txt_dir) + # res_file = os.path.join(txt_dir, '{}.txt'.format(im_prefix)) + # with open(res_file, 'w') as f: + # for i_box, box in enumerate(poly_list): + # seq_str = keep_str_list[i_box] + # box = np.round(box).astype('int32') + # box_str = ','.join(str(s) for s in (box.flatten().tolist())) + # f.write('{}\t{}\r\n'.format(box_str, seq_str)) + return data diff --git a/ppocr/postprocess/sast_postprocess.py b/ppocr/postprocess/sast_postprocess.py index f011e7e5..bee75c05 100755 --- a/ppocr/postprocess/sast_postprocess.py +++ b/ppocr/postprocess/sast_postprocess.py @@ -18,6 +18,7 @@ from __future__ import print_function import os import sys + __dir__ = os.path.dirname(__file__) sys.path.append(__dir__) sys.path.append(os.path.join(__dir__, '..')) @@ -49,12 +50,12 @@ class SASTPostProcess(object): self.shrink_ratio_of_width = shrink_ratio_of_width self.expand_scale = expand_scale self.tcl_map_thresh = tcl_map_thresh - + # c++ la-nms is faster, but only support python 3.5 self.is_python35 = False if sys.version_info.major == 3 and sys.version_info.minor == 5: self.is_python35 = True - + def point_pair2poly(self, point_pair_list): """ Transfer vertical point_pairs into poly point in clockwise. 
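Each center-line pixel is lifted to a pair of boundary points: add the two predicted (dy, dx) border offsets, swap (y, x) to (x, y), then rescale by the 4x network stride and the resize ratios, as in the loop above. A NumPy sketch of that mapping for a single pixel:

import numpy as np

def center_pixel_to_point_pair(y, x, border_offsets, ratio_w, ratio_h, stride=4.0):
    # border_offsets: the (4,) border-map slice at (y, x), i.e. two (dy, dx) offsets.
    offset = border_offsets.reshape(2, 2)
    ori_yx = np.array([y, x], dtype=np.float32)
    # Add offsets, flip (y, x) -> (x, y), then map back to source-image scale.
    return (ori_yx + offset)[:, ::-1] * stride / np.array([ratio_w, ratio_h]).reshape(-1, 2)

pair = center_pixel_to_point_pair(10, 20, np.array([-3., 0., 3., 0.]), 1.0, 1.0)
print(pair)  # [[80. 28.], [80. 52.]]: two (x, y) boundary points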
@@ -66,31 +67,42 @@ class SASTPostProcess(object): point_list[idx] = point_pair[0] point_list[point_num - 1 - idx] = point_pair[1] return np.array(point_list).reshape(-1, 2) - - def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.): + + def shrink_quad_along_width(self, + quad, + begin_width_ratio=0., + end_width_ratio=1.): """ Generate shrink_quad_along_width. """ - ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) - + def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3): """ expand poly along width. """ point_num = poly.shape[0] - left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ - (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) - left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0) - right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1], - poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32) + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, + 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) right_ratio = 1.0 + \ - shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ - (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) - right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio) + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, + right_ratio) poly[0] = left_quad_expand[0] poly[-1] = left_quad_expand[-1] poly[point_num // 2 - 1] = right_quad_expand[1] @@ -100,7 +112,7 @@ class SASTPostProcess(object): def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map): """Restore quad.""" xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) - xy_text = xy_text[:, ::-1] # (n, 2) + xy_text = xy_text[:, ::-1] # (n, 2) # Sort the text boxes via the y axis xy_text = xy_text[np.argsort(xy_text[:, 1])] @@ -112,7 +124,7 @@ class SASTPostProcess(object): point_num = int(tvo_map.shape[-1] / 2) assert point_num == 4 tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :] - xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2) + xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2) quads = xy_text_tile - tvo_map return scores, quads, xy_text @@ -121,14 +133,12 @@ class SASTPostProcess(object): """ compute area of a quad. """ - edge = [ - (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), - (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), - (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), - (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]) - ] + edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), + (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), + (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), + (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])] return np.sum(edge) / 2. 
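The quad area computed in the hunk above is the shoelace formula written as four edge terms; its sign encodes the vertex ordering. A standalone version for checking:

import numpy as np

def quad_area(quad):
    # Shoelace formula over four edges; quad is a (4, 2) array of (x, y) points.
    edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
            (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
            (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
            (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
    return np.sum(edge) / 2.

unit_square = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32)
print(quad_area(unit_square))  # -1.0; the sign reflects vertex ordering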
- + def nms(self, dets): if self.is_python35: import lanms @@ -141,7 +151,7 @@ class SASTPostProcess(object): """ Cluster pixels in tcl_map based on quads. """ - instance_count = quads.shape[0] + 1 # contain background + instance_count = quads.shape[0] + 1 # contain background instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32) if instance_count == 1: return instance_count, instance_label_map @@ -149,18 +159,19 @@ class SASTPostProcess(object): # predict text center xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) n = xy_text.shape[0] - xy_text = xy_text[:, ::-1] # (n, 2) - tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2) + xy_text = xy_text[:, ::-1] # (n, 2) + tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2) pred_tc = xy_text - tco - + # get gt text center m = quads.shape[0] - gt_tc = np.mean(quads, axis=1) # (m, 2) + gt_tc = np.mean(quads, axis=1) # (m, 2) - pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2) - gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2) - dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m) - xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,) + pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], + (1, m, 1)) # (n, m, 2) + gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2) + dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m) + xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,) instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign return instance_count, instance_label_map @@ -169,26 +180,47 @@ class SASTPostProcess(object): """ Estimate sample points number. """ - eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0 - ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0 + eh = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[1] - quad[2])) / 2.0 + ew = (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) / 2.0 dense_sample_pts_num = max(2, int(ew)) - dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num, - endpoint=True, dtype=np.float32).astype(np.int32)] - - dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1] - estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1)) + dense_xy_center_line = xy_text[np.linspace( + 0, + xy_text.shape[0] - 1, + dense_sample_pts_num, + endpoint=True, + dtype=np.float32).astype(np.int32)] + + dense_xy_center_line_diff = dense_xy_center_line[ + 1:] - dense_xy_center_line[:-1] + estimate_arc_len = np.sum( + np.linalg.norm( + dense_xy_center_line_diff, axis=1)) sample_pts_num = max(2, int(estimate_arc_len / eh)) return sample_pts_num - def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h, - shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0): + def detect_sast(self, + tcl_map, + tvo_map, + tbo_map, + tco_map, + ratio_w, + ratio_h, + src_w, + src_h, + shrink_ratio_of_width=0.3, + tcl_map_thresh=0.5, + offset_expand=1.0, + out_strid=4.0): """ first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys """ # restore quad - scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map) + scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, + tvo_map) dets = np.hstack((quads, scores)).astype(np.float32, copy=False) dets = self.nms(dets) if dets.shape[0] == 0: @@ -202,7 +234,8 @@ class 
SASTPostProcess(object): # instance segmentation # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8) - instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map) + instance_count, instance_label_map = self.cluster_by_quads_tco( + tcl_map, tcl_map_thresh, quads, tco_map) # restore single poly with tcl instance. poly_list = [] @@ -212,10 +245,10 @@ class SASTPostProcess(object): q_area = quad_areas[instance_idx - 1] if q_area < 5: continue - + # - len1 = float(np.linalg.norm(quad[0] -quad[1])) - len2 = float(np.linalg.norm(quad[1] -quad[2])) + len1 = float(np.linalg.norm(quad[0] - quad[1])) + len2 = float(np.linalg.norm(quad[1] - quad[2])) min_len = min(len1, len2) if min_len < 3: continue @@ -225,16 +258,18 @@ class SASTPostProcess(object): continue # filter low confidence instance - xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] + xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1: - # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: + # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: continue # sort xy_text - left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0, - (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2) - right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0, - (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2) + left_center_pt = np.array( + [[(quad[0, 0] + quad[-1, 0]) / 2.0, + (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2) + right_center_pt = np.array( + [[(quad[1, 0] + quad[2, 0]) / 2.0, + (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2) proj_unit_vec = (right_center_pt - left_center_pt) / \ (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6) proj_value = np.sum(xy_text * proj_unit_vec, axis=1) @@ -245,33 +280,45 @@ class SASTPostProcess(object): sample_pts_num = self.estimate_sample_pts_num(quad, xy_text) else: sample_pts_num = self.sample_pts_num - xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num, - endpoint=True, dtype=np.float32).astype(np.int32)] + xy_center_line = xy_text[np.linspace( + 0, + xy_text.shape[0] - 1, + sample_pts_num, + endpoint=True, + dtype=np.float32).astype(np.int32)] point_pair_list = [] for x, y in xy_center_line: # get corresponding offset offset = tbo_map[y, x, :].reshape(2, 2) if offset_expand != 1.0: - offset_length = np.linalg.norm(offset, axis=1, keepdims=True) - expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0) + offset_length = np.linalg.norm( + offset, axis=1, keepdims=True) + expand_length = np.clip( + offset_length * (offset_expand - 1), + a_min=0.5, + a_max=3.0) offset_detal = offset / offset_length * expand_length - offset = offset + offset_detal - # original point + offset = offset + offset_detal + # original point ori_yx = np.array([y, x], dtype=np.float32) - point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2) + point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) point_pair_list.append(point_pair) # ndarry: (x, 2), expand poly along width detected_poly = self.point_pair2poly(point_pair_list) - detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width) - detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) - detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + detected_poly = 
self.expand_poly_along_width(detected_poly, + shrink_ratio_of_width) + detected_poly[:, 0] = np.clip( + detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip( + detected_poly[:, 1], a_min=0, a_max=src_h) poly_list.append(detected_poly) return poly_list - def __call__(self, outs_dict, shape_list): + def __call__(self, outs_dict, shape_list): score_list = outs_dict['f_score'] border_list = outs_dict['f_border'] tvo_list = outs_dict['f_tvo'] @@ -281,20 +328,28 @@ class SASTPostProcess(object): border_list = border_list.numpy() tvo_list = tvo_list.numpy() tco_list = tco_list.numpy() - + img_num = len(shape_list) poly_lists = [] for ino in range(img_num): - p_score = score_list[ino].transpose((1,2,0)) - p_border = border_list[ino].transpose((1,2,0)) - p_tvo = tvo_list[ino].transpose((1,2,0)) - p_tco = tco_list[ino].transpose((1,2,0)) + p_score = score_list[ino].transpose((1, 2, 0)) + p_border = border_list[ino].transpose((1, 2, 0)) + p_tvo = tvo_list[ino].transpose((1, 2, 0)) + p_tco = tco_list[ino].transpose((1, 2, 0)) src_h, src_w, ratio_h, ratio_w = shape_list[ino] - poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h, - shrink_ratio_of_width=self.shrink_ratio_of_width, - tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale) + poly_list = self.detect_sast( + p_score, + p_tvo, + p_border, + p_tco, + ratio_w, + ratio_h, + src_w, + src_h, + shrink_ratio_of_width=self.shrink_ratio_of_width, + tcl_map_thresh=self.tcl_map_thresh, + offset_expand=self.expand_scale) poly_lists.append({'points': np.array(poly_list)}) return poly_lists - diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py new file mode 100755 index 00000000..fd12ecab --- /dev/null +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -0,0 +1,877 @@ +from os import listdir +import os, sys +from scipy import io +import numpy as np +from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area +from tqdm import tqdm + +try: # python2 + range = xrange +except Exception: + # python3 + range = range +""" +Input format: y0,x0, ..... yn,xn. 
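Before scoring, each predicted polygon is flattened into one comma-joined coordinate string, matching the input format described above; the reader inside get_socre below does exactly this with points.reshape(-1, ). A small sketch:

import numpy as np

points = np.array([[10, 20], [30, 20], [30, 40], [10, 40]])  # one predicted polygon
point_str = ",".join(map(str, points.reshape(-1, )))
print(point_str)  # '10,20,30,20,30,40,10,40'
det_entry = [point_str, 'HELLO']  # [flattened coordinates, predicted transcription]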
Each detection is separated by the end of line token ('\n')' +""" + +# if len(sys.argv) != 4: +# print('\n usage: test.py pred_dir gt_dir savefile') +# sys.exit() + + +def get_socre(gt_dict, pred_dict): + # allInputs = listdir(input_dir) + allInputs = 1 + + def input_reading_mod(pred_dict, input): + """This helper reads input from txt files""" + det = [] + n = len(pred_dict) + for i in range(n): + points = pred_dict[i]['points'] + text = pred_dict[i]['text'] + # for i in range(len(points)): + point = ",".join(map(str, points.reshape(-1, ))) + det.append([point, text]) + return det + + def gt_reading_mod(gt_dict, gt_id): + """This helper reads groundtruths from mat files""" + # gt_id = gt_id.split('.')[0] + gt = [] + n = len(gt_dict) + for i in range(n): + points = gt_dict[i]['points'].tolist() + h = len(points) + text = gt_dict[i]['text'] + xx = [ + np.array( + ['x:'], dtype=' 1): + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + # detection = detection.split(',') + detection = list(map(int, detection)) + det_x = detection[0::2] + det_y = detection[1::2] + det_gt_iou = iod(det_x, det_y, gt_x, gt_y) + if det_gt_iou > threshold: + detections[det_id] = [] + + detections[:] = [item for item in detections if item != []] + return detections + + def sigma_calculation(det_x, det_y, gt_x, gt_y): + """ + sigma = inter_area / gt_area + """ + # print(area_of_intersection(det_x, det_y, gt_x, gt_y)) + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(gt_x, gt_y)), 2) + + def tau_calculation(det_x, det_y, gt_x, gt_y): + """ + tau = inter_area / det_area + """ + # print "liushanshan det_x {}".format(det_x) + # print "liushanshan det_y {}".format(det_y) + # print "liushanshan area {}".format(area(det_x, det_y)) + # print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2)) + if area(det_x, det_y) == 0.0: + return 0 + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(det_x, det_y)), 2) + + ##############################Initialization################################### + global_tp = 0 + global_fp = 0 + global_fn = 0 + global_sigma = [] + global_tau = [] + tr = 0.7 + tp = 0.6 + fsc_k = 0.8 + k = 2 + global_pred_str = [] + global_gt_str = [] + ############################################################################### + + for input_id in range(allInputs): + if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and ( + input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and ( + input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ + and (input_id != 'Deteval_result_non_curved.txt'): + print(input_id) + detections = input_reading_mod(pred_dict, input_id) + # print "liushanshan detections = {}".format(detections) + groundtruths = gt_reading_mod(gt_dict, input_id) + detections = detection_filtering( + detections, + groundtruths) # filters detections overlapping with DC area + dc_id = [] + for i in range(len(groundtruths)): + if groundtruths[i][5] == '#': + dc_id.append(i) + cnt = 0 + for a in dc_id: + num = a - cnt + del groundtruths[num] + cnt += 1 + + local_sigma_table = np.zeros((len(groundtruths), len(detections))) + local_tau_table = np.zeros((len(groundtruths), len(detections))) + local_pred_str = {} + local_gt_str = {} + + for gt_id, gt in 
enumerate(groundtruths): + if len(detections) > 0: + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + pred_seq_str = detection_orig[1].strip() + det_x = detection[0::2] + det_y = detection[1::2] + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + gt_seq_str = str(gt[4].tolist()[0]) + + local_sigma_table[gt_id, det_id] = sigma_calculation( + det_x, det_y, gt_x, gt_y) + local_tau_table[gt_id, det_id] = tau_calculation( + det_x, det_y, gt_x, gt_y) + local_pred_str[det_id] = pred_seq_str + local_gt_str[gt_id] = gt_seq_str + + global_sigma.append(local_sigma_table) + global_tau.append(local_tau_table) + global_pred_str.append(local_pred_str) + global_gt_str.append(local_gt_str) + print + "liushanshan global_pred_str = {}".format(global_pred_str) + print + "liushanshan global_gt_str = {}".format(global_gt_str) + + global_accumulative_recall = 0 + global_accumulative_precision = 0 + total_num_gt = 0 + total_num_det = 0 + hit_str_count = 0 + hit_count = 0 + + def one_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for gt_id in range(num_gt): + gt_matching_qualified_sigma_candidates = np.where( + local_sigma_table[gt_id, :] > tr) + gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[ + 0].shape[0] + gt_matching_qualified_tau_candidates = np.where( + local_tau_table[gt_id, :] > tp) + gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[ + 0].shape[0] + + det_matching_qualified_sigma_candidates = np.where( + local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] + > tr) + det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[ + 0].shape[0] + det_matching_qualified_tau_candidates = np.where( + local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > + tp) + det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[ + 0].shape[0] + + if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \ + (det_matching_num_qualified_sigma_candidates == 1) and ( + det_matching_num_qualified_tau_candidates == 1): + global_accumulative_recall = global_accumulative_recall + 1.0 + global_accumulative_precision = global_accumulative_precision + 1.0 + local_accumulative_recall = local_accumulative_recall + 1.0 + local_accumulative_precision = local_accumulative_precision + 1.0 + + gt_flag[0, gt_id] = 1 + matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) + # recg start + print + "liushanshan one to one det_id = {}".format(matched_det_id) + print + "liushanshan one to one gt_id = {}".format(gt_id) + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ + 0]] + print + "liushanshan one to one gt_str_cur = {}".format(gt_str_cur) + print + "liushanshan one to one pred_str_cur = {}".format(pred_str_cur) + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + # recg end + det_flag[0, matched_det_id] = 1 + return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + + def one_to_many(local_sigma_table, 
local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for gt_id in range(num_gt): + # skip the following if the groundtruth was matched + if gt_flag[0, gt_id] > 0: + continue + + non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0) + num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0] + + if num_non_zero_in_sigma >= k: + ####search for all detections that overlaps with this groundtruth + qualified_tau_candidates = np.where((local_tau_table[ + gt_id, :] >= tp) & (det_flag[0, :] == 0)) + num_qualified_tau_candidates = qualified_tau_candidates[ + 0].shape[0] + + if num_qualified_tau_candidates == 1: + if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) + and + (local_sigma_table[gt_id, qualified_tau_candidates] >= + tr)): + # became an one-to-one case + global_accumulative_recall = global_accumulative_recall + 1.0 + global_accumulative_precision = global_accumulative_precision + 1.0 + local_accumulative_recall = local_accumulative_recall + 1.0 + local_accumulative_precision = local_accumulative_precision + 1.0 + + gt_flag[0, gt_id] = 1 + det_flag[0, qualified_tau_candidates] = 1 + # recg start + print + "liushanshan one to many det_id = {}".format( + qualified_tau_candidates) + print + "liushanshan one to many gt_id = {}".format(gt_id) + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + print + "liushanshan one to many gt_str_cur = {}".format( + gt_str_cur) + print + "liushanshan one to many pred_str_cur = {}".format( + pred_str_cur) + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + # recg end + elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) + >= tr): + gt_flag[0, gt_id] = 1 + det_flag[0, qualified_tau_candidates] = 1 + # recg start + print + "liushanshan one to many det_id = {}".format( + qualified_tau_candidates) + print + "liushanshan one to many gt_id = {}".format(gt_id) + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + print + "liushanshan one to many gt_str_cur = {}".format(gt_str_cur) + print + "liushanshan one to many pred_str_cur = {}".format( + pred_str_cur) + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + # recg end + + global_accumulative_recall = global_accumulative_recall + fsc_k + global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k + + local_accumulative_recall = local_accumulative_recall + fsc_k + local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k + + return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + + def many_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for det_id in range(num_det): + # skip the following if the detection was matched + if det_flag[0, det_id] > 0: + continue + + non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0) + num_non_zero_in_tau = non_zero_in_tau[0].shape[0] + + if num_non_zero_in_tau >= k: + ####search for all 
detections that overlaps with this groundtruth + qualified_sigma_candidates = np.where(( + local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0)) + num_qualified_sigma_candidates = qualified_sigma_candidates[ + 0].shape[0] + + if num_qualified_sigma_candidates == 1: + if ((local_tau_table[qualified_sigma_candidates, det_id] >= + tp) and + (local_sigma_table[qualified_sigma_candidates, det_id] + >= tr)): + # became an one-to-one case + global_accumulative_recall = global_accumulative_recall + 1.0 + global_accumulative_precision = global_accumulative_precision + 1.0 + local_accumulative_recall = local_accumulative_recall + 1.0 + local_accumulative_precision = local_accumulative_precision + 1.0 + + gt_flag[0, qualified_sigma_candidates] = 1 + det_flag[0, det_id] = 1 + # recg start + print + "liushanshan many to one det_id = {}".format(det_id) + print + "liushanshan many to one gt_id = {}".format( + qualified_sigma_candidates) + pred_str_cur = global_pred_str[idy][det_id] + gt_len = len(qualified_sigma_candidates[0]) + for idx in range(gt_len): + ele_gt_id = qualified_sigma_candidates[0].tolist()[ + idx] + if not global_gt_str[idy].has_key(ele_gt_id): + continue + gt_str_cur = global_gt_str[idy][ele_gt_id] + print + "liushanshan many to one gt_str_cur = {}".format( + gt_str_cur) + print + "liushanshan many to one pred_str_cur = {}".format( + pred_str_cur) + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + break + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + break + # recg end + elif (np.sum(local_tau_table[qualified_sigma_candidates, + det_id]) >= tp): + det_flag[0, det_id] = 1 + gt_flag[0, qualified_sigma_candidates] = 1 + # recg start + print + "liushanshan many to one det_id = {}".format(det_id) + print + "liushanshan many to one gt_id = {}".format( + qualified_sigma_candidates) + pred_str_cur = global_pred_str[idy][det_id] + gt_len = len(qualified_sigma_candidates[0]) + for idx in range(gt_len): + ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] + if not global_gt_str[idy].has_key(ele_gt_id): + continue + gt_str_cur = global_gt_str[idy][ele_gt_id] + print + "liushanshan many to one gt_str_cur = {}".format( + gt_str_cur) + print + "liushanshan many to one pred_str_cur = {}".format( + pred_str_cur) + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + break + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + break + else: + print + 'no match' + # recg end + + global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k + global_accumulative_precision = global_accumulative_precision + fsc_k + + local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k + local_accumulative_precision = local_accumulative_precision + fsc_k + return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + + single_data = {} + for idx in range(len(global_sigma)): + # print(allInputs[idx]) + local_sigma_table = global_sigma[idx] + local_tau_table = global_tau[idx] + + num_gt = local_sigma_table.shape[0] + num_det = local_sigma_table.shape[1] + + total_num_gt = total_num_gt + num_gt + total_num_det = total_num_det + num_det + + local_accumulative_recall = 0 + local_accumulative_precision = 0 + gt_flag = np.zeros((1, num_gt)) + det_flag = np.zeros((1, num_det)) + + #######first check for one-to-one case########## + local_accumulative_recall, local_accumulative_precision, 
global_accumulative_recall, global_accumulative_precision, \ + gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idx) + + hit_str_count += hit_str_num + #######then check for one-to-many case########## + local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ + gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idx) + hit_str_count += hit_str_num + #######then check for many-to-one case########## + local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ + gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idx) + + hit_str_count += hit_str_num + + # fid = open(fid_path, 'a+') + try: + local_precision = local_accumulative_precision / num_det + except ZeroDivisionError: + local_precision = 0 + + try: + local_recall = local_accumulative_recall / num_gt + except ZeroDivisionError: + local_recall = 0 + + try: + local_f_score = 2 * local_precision * local_recall / ( + local_precision + local_recall) + except ZeroDivisionError: + local_f_score = 0 + + # temp = ('%s: Recall=%.4f, Precision=%.4f, f_score=%.4f\n' % ( + # allInputs[idx], local_recall, local_precision, local_f_score)) + single_data['sigma'] = global_sigma + single_data['global_tau'] = global_tau + single_data['global_pred_str'] = global_pred_str + single_data['global_gt_str'] = global_gt_str + single_data["recall"] = local_recall + single_data['precision'] = local_precision + single_data['f_score'] = local_f_score + return single_data + + +def combine_results(all_data): + tr = 0.7 + tp = 0.6 + fsc_k = 0.8 + k = 2 + global_sigma = [] + global_tau = [] + global_pred_str = [] + global_gt_str = [] + for data in all_data: + global_sigma.append(data['sigma'][0]) + global_tau.append(data['global_tau'][0]) + global_pred_str.append(data['global_pred_str'][0]) + global_gt_str.append(data['global_gt_str'][0]) + + global_accumulative_recall = 0 + global_accumulative_precision = 0 + total_num_gt = 0 + total_num_det = 0 + hit_str_count = 0 + hit_count = 0 + + def one_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for gt_id in range(num_gt): + gt_matching_qualified_sigma_candidates = np.where( + local_sigma_table[gt_id, :] > tr) + gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[ + 0].shape[0] + gt_matching_qualified_tau_candidates = np.where( + local_tau_table[gt_id, :] > tp) + gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[ + 0].shape[0] + + det_matching_qualified_sigma_candidates = np.where( + local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] + > tr) + det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[ + 0].shape[0] + det_matching_qualified_tau_candidates = np.where( + local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > + tp) + 
+            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
+                0].shape[0]
+
+            if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
+                    (det_matching_num_qualified_sigma_candidates == 1) and (
+                    det_matching_num_qualified_tau_candidates == 1):
+                global_accumulative_recall = global_accumulative_recall + 1.0
+                global_accumulative_precision = global_accumulative_precision + 1.0
+                local_accumulative_recall = local_accumulative_recall + 1.0
+                local_accumulative_precision = local_accumulative_precision + 1.0
+
+                gt_flag[0, gt_id] = 1
+                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
+                # recg start
+                gt_str_cur = global_gt_str[idy][gt_id]
+                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
+                    0]]
+                if pred_str_cur == gt_str_cur:
+                    hit_str_num += 1
+                else:
+                    if pred_str_cur.lower() == gt_str_cur.lower():
+                        hit_str_num += 1
+                # recg end
+                det_flag[0, matched_det_id] = 1
+        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+    def one_to_many(local_sigma_table, local_tau_table,
+                    local_accumulative_recall, local_accumulative_precision,
+                    global_accumulative_recall, global_accumulative_precision,
+                    gt_flag, det_flag, idy):
+        hit_str_num = 0
+        for gt_id in range(num_gt):
+            # skip the following if the groundtruth was matched
+            if gt_flag[0, gt_id] > 0:
+                continue
+
+            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
+            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
+
+            if num_non_zero_in_sigma >= k:
+                ####search for all detections that overlap with this groundtruth
+                qualified_tau_candidates = np.where((local_tau_table[
+                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
+                num_qualified_tau_candidates = qualified_tau_candidates[
+                    0].shape[0]
+
+                if num_qualified_tau_candidates == 1:
+                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
+                            and
+                            (local_sigma_table[gt_id, qualified_tau_candidates] >=
+                             tr)):
+                        # becomes a one-to-one case
+                        global_accumulative_recall = global_accumulative_recall + 1.0
+                        global_accumulative_precision = global_accumulative_precision + 1.0
+                        local_accumulative_recall = local_accumulative_recall + 1.0
+                        local_accumulative_precision = local_accumulative_precision + 1.0
+
+                        gt_flag[0, gt_id] = 1
+                        det_flag[0, qualified_tau_candidates] = 1
+                        # recg start
+                        gt_str_cur = global_gt_str[idy][gt_id]
+                        pred_str_cur = global_pred_str[idy][
+                            qualified_tau_candidates[0].tolist()[0]]
+                        if pred_str_cur == gt_str_cur:
+                            hit_str_num += 1
+                        else:
+                            if pred_str_cur.lower() == gt_str_cur.lower():
+                                hit_str_num += 1
+                        # recg end
+                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
+                      >= tr):
+                    gt_flag[0, gt_id] = 1
+                    det_flag[0, qualified_tau_candidates] = 1
+                    # recg start
+                    gt_str_cur = global_gt_str[idy][gt_id]
+                    pred_str_cur = global_pred_str[idy][
+                        qualified_tau_candidates[0].tolist()[0]]
+                    if pred_str_cur == gt_str_cur:
+                        hit_str_num += 1
+                    else:
+                        if pred_str_cur.lower() == gt_str_cur.lower():
+                            hit_str_num += 1
+                    # recg end
+
+                    global_accumulative_recall = global_accumulative_recall + fsc_k
+                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
+
+                    local_accumulative_recall = local_accumulative_recall + fsc_k
+                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
+
+        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+    def many_to_one(local_sigma_table, local_tau_table,
+                    local_accumulative_recall, local_accumulative_precision,
+                    global_accumulative_recall, global_accumulative_precision,
+                    gt_flag, det_flag, idy):
+        hit_str_num = 0
+        for det_id in range(num_det):
+            # skip the following if the detection was matched
+            if det_flag[0, det_id] > 0:
+                continue
+
+            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
+            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
+
+            if num_non_zero_in_tau >= k:
+                ####search for all groundtruths that overlap with this detection
+                qualified_sigma_candidates = np.where((
+                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
+                num_qualified_sigma_candidates = qualified_sigma_candidates[
+                    0].shape[0]
+
+                if num_qualified_sigma_candidates == 1:
+                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
+                         tp) and
+                            (local_sigma_table[qualified_sigma_candidates, det_id]
+                             >= tr)):
+                        # becomes a one-to-one case
+                        global_accumulative_recall = global_accumulative_recall + 1.0
+                        global_accumulative_precision = global_accumulative_precision + 1.0
+                        local_accumulative_recall = local_accumulative_recall + 1.0
+                        local_accumulative_precision = local_accumulative_precision + 1.0
+
+                        gt_flag[0, qualified_sigma_candidates] = 1
+                        det_flag[0, det_id] = 1
+                        # recg start
+                        pred_str_cur = global_pred_str[idy][det_id]
+                        gt_len = len(qualified_sigma_candidates[0])
+                        for idx in range(gt_len):
+                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
+                                idx]
+                            if ele_gt_id not in global_gt_str[idy]:
+                                continue
+                            gt_str_cur = global_gt_str[idy][ele_gt_id]
+                            if pred_str_cur == gt_str_cur:
+                                hit_str_num += 1
+                                break
+                            else:
+                                if pred_str_cur.lower() == gt_str_cur.lower():
+                                    hit_str_num += 1
+                                    break
+                        # recg end
+                elif (np.sum(local_tau_table[qualified_sigma_candidates,
+                                             det_id]) >= tp):
+                    det_flag[0, det_id] = 1
+                    gt_flag[0, qualified_sigma_candidates] = 1
+                    # recg start
+                    pred_str_cur = global_pred_str[idy][det_id]
+                    gt_len = len(qualified_sigma_candidates[0])
+                    for idx in range(gt_len):
+                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
+                        if ele_gt_id not in global_gt_str[idy]:
+                            continue
+                        gt_str_cur = global_gt_str[idy][ele_gt_id]
+                        if pred_str_cur == gt_str_cur:
+                            hit_str_num += 1
+                            break
+                        else:
+                            if pred_str_cur.lower() == gt_str_cur.lower():
+                                hit_str_num += 1
+                                break
+                    # recg end
+
+                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
+                    global_accumulative_precision = global_accumulative_precision + fsc_k
+
+                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
+                    local_accumulative_precision = local_accumulative_precision + fsc_k
+        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
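+    # Matching runs in three passes over each image's sigma/tau tables:
+    # one-to-one pairs first (each adds 1.0 to the recall and precision
+    # numerators), then one-to-many (one groundtruth split across several
+    # detections) and many-to-one (several groundtruths merged into one
+    # detection), both discounted by fsc_k = 0.8. For example, a groundtruth
+    # matched by two detections in the one-to-many pass contributes 0.8 to
+    # the recall numerator and 2 * 0.8 to the precision numerator.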
+    for idx in range(len(global_sigma)):
+        local_sigma_table = np.array(global_sigma[idx])
+        local_tau_table = global_tau[idx]
+
+        num_gt = local_sigma_table.shape[0]
+        num_det = local_sigma_table.shape[1]
+
+        total_num_gt = total_num_gt + num_gt
+        total_num_det = total_num_det + num_det
+
+        local_accumulative_recall = 0
+        local_accumulative_precision = 0
+        gt_flag = np.zeros((1, num_gt))
+        det_flag = np.zeros((1, num_det))
+
+        #######first check for one-to-one case##########
+        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+        gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
+                                                    local_accumulative_recall, local_accumulative_precision,
+                                                    global_accumulative_recall, global_accumulative_precision,
+                                                    gt_flag, det_flag, idx)
+
+        hit_str_count += hit_str_num
+        #######then check for one-to-many case##########
+        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+        gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
+                                                     local_accumulative_recall, local_accumulative_precision,
+                                                     global_accumulative_recall, global_accumulative_precision,
+                                                     gt_flag, det_flag, idx)
+        hit_str_count += hit_str_num
+        #######then check for many-to-one case##########
+        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+        gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
+                                                     local_accumulative_recall, local_accumulative_precision,
+                                                     global_accumulative_recall, global_accumulative_precision,
+                                                     gt_flag, det_flag, idx)
+        hit_str_count += hit_str_num
+
+    try:
+        recall = global_accumulative_recall / total_num_gt
+    except ZeroDivisionError:
+        recall = 0
+
+    try:
+        precision = global_accumulative_precision / total_num_det
+    except ZeroDivisionError:
+        precision = 0
+
+    try:
+        f_score = 2 * precision * recall / (precision + recall)
+    except ZeroDivisionError:
+        f_score = 0
+
+    try:
+        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
+    except ZeroDivisionError:
+        seqerr = 1
+
+    try:
+        recall_e2e = float(hit_str_count) / total_num_gt
+    except ZeroDivisionError:
+        recall_e2e = 0
+
+    try:
+        precision_e2e = float(hit_str_count) / total_num_det
+    except ZeroDivisionError:
+        precision_e2e = 0
+
+    try:
+        f_score_e2e = 2 * precision_e2e * recall_e2e / (
+            precision_e2e + recall_e2e)
+    except ZeroDivisionError:
+        f_score_e2e = 0
+
+    final = {
+        'total_num_gt': total_num_gt,
+        'total_num_det': total_num_det,
+        'global_accumulative_recall': global_accumulative_recall,
+        'hit_str_count': hit_str_count,
+        'recall': recall,
+        'precision': precision,
+        'f_score': f_score,
+        'seqerr': seqerr,
+        'recall_e2e': recall_e2e,
+        'precision_e2e': precision_e2e,
+        'f_score_e2e': f_score_e2e
+    }
+    return final
+
+
+# a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620,
+#      1659, 620, 1654, 681, 1631, 680, 1618, 681, 1606, 681, 1594, 681, 1584, 682,
+#      1573, 685, 1542, 694]
+# gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
+# pred_dict = [{'points': np.array(a), 'text': 'ccc'},
+#              {'points': np.array(a), 'text': 'ccf'}]
+# result = []
+# for i in range(2):
+#     result.append(get_socre(gt_dict, pred_dict))
+# a = combine_results(result)
+# print(a)
diff --git a/ppocr/utils/e2e_metric/polygon_fast.py b/ppocr/utils/e2e_metric/polygon_fast.py
new file mode 100755
index 00000000..0173212e
--- /dev/null
+++ b/ppocr/utils/e2e_metric/polygon_fast.py
@@ -0,0 +1,71 @@
+import numpy as np
+from shapely.geometry import Polygon
+"""
+:param det_x: [1, N] Xs of detection's vertices
+:param det_y: [1, N] Ys of detection's vertices
+:param gt_x: [1, N] Xs of groundtruth's vertices
+:param gt_y: [1, N] Ys of groundtruth's vertices
+
+##############
+All the 'AREA' calculations in this script are handled by shapely:
+polygons are built from the vertex lists, and area / intersection / union
+come from the corresponding shapely operations.
+"""
+
+
+def area(x, y):
+    polygon = Polygon(np.stack([x, y], axis=1))
+    return float(polygon.area)
+
+
+def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
+    """
+    This helper determines whether the two polygons intersect, using an
+    approximation: the area of intersection is represented by the overlap of
+    the minimum bounding rectangles [xmin, ymin, xmax, ymax].
+    """
+    det_ymax = np.max(det_y)
+    det_xmax = np.max(det_x)
+    det_ymin = np.min(det_y)
+    det_xmin = np.min(det_x)
+
+    gt_ymax = np.max(gt_y)
+    gt_xmax = np.max(gt_x)
+    gt_ymin = np.min(gt_y)
+    gt_xmin = np.min(gt_x)
+
+    all_min_ymax = np.minimum(det_ymax, gt_ymax)
+    all_max_ymin = np.maximum(det_ymin, gt_ymin)
+
+    intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
+
+    all_min_xmax = np.minimum(det_xmax, gt_xmax)
+    all_max_xmin = np.maximum(det_xmin, gt_xmin)
+    intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
+
+    return intersect_heights * intersect_widths
+
+
+def area_of_intersection(det_x, det_y, gt_x, gt_y):
+    p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
+    p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
+    return float(p1.intersection(p2).area)
+
+
+def area_of_union(det_x, det_y, gt_x, gt_y):
+    p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
+    p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
+    return float(p1.union(p2).area)
+
+
+def iou(det_x, det_y, gt_x, gt_y):
+    # the +1.0 in the denominator guards against an empty union
+    return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
+        area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
+
+
+def iod(det_x, det_y, gt_x, gt_y):
+    """
+    This helper determines the fraction of the intersection area over the
+    detection area.
+    """
+    return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
+        area(det_x, det_y) + 1.0)
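+
+
+# Quick sanity check (illustrative only): two axis-aligned unit squares
+# offset by half a unit share 0.5 area and their union is 1.5, so iou()
+# returns roughly 0.5 / (1.5 + 1.0) = 0.2 because of the +1.0 guard in the
+# denominator.
+if __name__ == '__main__':
+    sq_x = [0.0, 1.0, 1.0, 0.0]
+    sq_y = [0.0, 0.0, 1.0, 1.0]
+    shifted_x = [0.5, 1.5, 1.5, 0.5]
+    print(area_of_intersection(sq_x, sq_y, shifted_x, sq_y))  # ~0.5
+    print(iou(sq_x, sq_y, shifted_x, sq_y))  # ~0.2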
diff --git a/ppocr/utils/e2e_metric/tttt.py b/ppocr/utils/e2e_metric/tttt.py
new file mode 100644
index 00000000..91d893fd
--- /dev/null
+++ b/ppocr/utils/e2e_metric/tttt.py
@@ -0,0 +1,881 @@
+from os import listdir
+import os, sys
+from scipy import io
+import numpy as np
+from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
+from tqdm import tqdm
+
+try:  # python2
+    range = xrange
+except Exception:
+    # python3
+    range = range
+"""
+Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n').
+"""
+
+global_tp = 0
+global_fp = 0
+global_fn = 0
+
+tr = 0.7
+tp = 0.6
+fsc_k = 0.8
+k = 2
+
+
+def get_socre(gt_dict, pred_dict):
+    allInputs = 1
+    global_pred_str = []
+    global_gt_str = []
+    global_sigma = []
+    global_tau = []
+
+    def input_reading_mod(pred_dict, input):
+        """This helper builds the detection list from the prediction dict"""
+        det = []
+        n = len(pred_dict)
+        for i in range(n):
+            points = pred_dict[i]['points']
+            text = pred_dict[i]['text']
+            point = ",".join(map(str, points.reshape(-1, )))
+            det.append([point, text])
+        return det
+
+    def gt_reading_mod(gt_dict, gt_id):
+        """This helper builds the groundtruth list from the groundtruth dict"""
+        gt = []
+        n = len(gt_dict)
+        for i in range(n):
+            points = gt_dict[i]['points'].tolist()
+            h = len(points)
+            text = gt_dict[i]['text']
+            # mimic the mat-file record layout: x array at index 1, y array
+            # at index 3, transcription at index 4, don't-care flag ('#')
+            # at index 5
+            xx = [
+                np.array(
+                    ['x:'], dtype='<U2'), 0, np.array(
+                        ['y:'], dtype='<U2'), 0, np.array(
+                            [text], dtype='<U{}'.format(len(text))), np.array(
+                                [text], dtype='<U{}'.format(len(text)))
+            ]
+            t_x, t_y = [], []
+            for j in range(h):
+                t_x.append(points[j][0])
+                t_y.append(points[j][1])
+            xx[1] = np.array([t_x], dtype='int16')
+            xx[3] = np.array([t_y], dtype='int16')
+            gt.append(xx)
+        return gt
+
+    def detection_filtering(detections, groundtruths, threshold=0.5):
+        for gt_id, gt in enumerate(groundtruths):
+            if (gt[5] == '#') and (gt[1].shape[1] > 1):
+                gt_x = list(map(int, np.squeeze(gt[1])))
+                gt_y = list(map(int, np.squeeze(gt[3])))
+                for det_id, detection in enumerate(detections):
+                    detection_orig = detection
+                    detection = [float(x) for x in detection[0].split(',')]
+                    detection = list(map(int, detection))
+                    det_x = detection[0::2]
+                    det_y = detection[1::2]
+                    det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
+                    if det_gt_iou > threshold:
+                        detections[det_id] = []
+
+                detections[:] = [item for item in detections if item != []]
+        return detections
+
+    def sigma_calculation(det_x, det_y, gt_x, gt_y):
+        """
+        sigma = inter_area / gt_area
+        """
+        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+                         area(gt_x, gt_y)), 2)
+
+    def tau_calculation(det_x, det_y, gt_x, gt_y):
+        """
+        tau = inter_area / det_area
+        """
+        if area(det_x, det_y) == 0.0:
+            return 0
+        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+                         area(det_x, det_y)), 2)
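+    # Worked example (illustrative): if a detection overlaps 6 area units of
+    # a groundtruth that spans 10 units, and the detection itself spans 8
+    # units, then sigma = 6 / 10 = 0.6 (recall-like) and tau = 6 / 8 = 0.75
+    # (precision-like); both are rounded to two decimals above.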
+    ##############################Initialization###################################
+    ###############################################################################
+    single_data = {}
+    for input_id in range(allInputs):
+        if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
+                input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
+                input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
+                and (input_id != 'Deteval_result_non_curved.txt'):
+            detections = input_reading_mod(pred_dict, input_id)
+            groundtruths = gt_reading_mod(gt_dict, input_id)
+            detections = detection_filtering(
+                detections,
+                groundtruths)  # filters detections overlapping with DC area
+            dc_id = []
+            for i in range(len(groundtruths)):
+                if groundtruths[i][5] == '#':
+                    dc_id.append(i)
+            cnt = 0
+            for a in dc_id:
+                num = a - cnt
+                del groundtruths[num]
+                cnt += 1
+
+            local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+            local_tau_table = np.zeros((len(groundtruths), len(detections)))
+            local_pred_str = {}
+            local_gt_str = {}
+
+            for gt_id, gt in enumerate(groundtruths):
+                if len(detections) > 0:
+                    for det_id, detection in enumerate(detections):
+                        detection_orig = detection
+                        detection = [float(x) for x in detection[0].split(',')]
+                        detection = list(map(int, detection))
+                        pred_seq_str = detection_orig[1].strip()
+                        det_x = detection[0::2]
+                        det_y = detection[1::2]
+                        gt_x = list(map(int, np.squeeze(gt[1])))
+                        gt_y = list(map(int, np.squeeze(gt[3])))
+                        gt_seq_str = str(gt[4].tolist()[0])
+
+                        local_sigma_table[gt_id, det_id] = sigma_calculation(
+                            det_x, det_y, gt_x, gt_y)
+                        local_tau_table[gt_id, det_id] = tau_calculation(
+                            det_x, det_y, gt_x, gt_y)
+                        local_pred_str[det_id] = pred_seq_str
+                        local_gt_str[gt_id] = gt_seq_str
+
+            global_sigma.append(local_sigma_table)
+            global_tau.append(local_tau_table)
+            global_pred_str.append(local_pred_str)
+            global_gt_str.append(local_gt_str)
+
+    single_data['sigma'] = global_sigma
+    single_data['global_tau'] = global_tau
+    single_data['global_pred_str'] = global_pred_str
+    single_data['global_gt_str'] = global_gt_str
+    return single_data
+
+
+def combine_results(all_data):
+    global_sigma, global_tau, global_pred_str, global_gt_str = [], [], [], []
+    for data in all_data:
+        global_sigma.append(data['sigma'])
+        global_tau.append(data['global_tau'])
+        global_pred_str.append(data['global_pred_str'])
+        global_gt_str.append(data['global_gt_str'])
+    global_accumulative_recall = 0
+    global_accumulative_precision = 0
+    total_num_gt = 0
+    total_num_det = 0
+    hit_str_count = 0
+    hit_count = 0
+
+    def one_to_one(local_sigma_table, local_tau_table,
+                   local_accumulative_recall, local_accumulative_precision,
+                   global_accumulative_recall, global_accumulative_precision,
+                   gt_flag, det_flag, idy):
+        hit_str_num = 0
+        for gt_id in range(num_gt):
+            gt_matching_qualified_sigma_candidates = np.where(
+                local_sigma_table[gt_id, :] > tr)
+            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
+                0].shape[0]
+            gt_matching_qualified_tau_candidates = np.where(
+                local_tau_table[gt_id, :] > tp)
+            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
+                0].shape[0]
+
+            det_matching_qualified_sigma_candidates = np.where(
+                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
+                > tr)
+            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
+                0].shape[0]
+            det_matching_qualified_tau_candidates = np.where(
+                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
+                tp)
+            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
+                0].shape[0]
+
+            if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
+                    (det_matching_num_qualified_sigma_candidates == 1) and (
+                    det_matching_num_qualified_tau_candidates == 1):
+                global_accumulative_recall = global_accumulative_recall + 1.0
+                global_accumulative_precision = global_accumulative_precision + 1.0
+                local_accumulative_recall = local_accumulative_recall + 1.0
+                local_accumulative_precision = local_accumulative_precision + 1.0
+
+                gt_flag[0, gt_id] = 1
+                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
+                # recg start
+                gt_str_cur = global_gt_str[idy][gt_id]
+                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
+                    0]]
+                if pred_str_cur == gt_str_cur:
+                    hit_str_num += 1
+                else:
+                    if pred_str_cur.lower() == gt_str_cur.lower():
+                        hit_str_num += 1
+                # recg end
+                det_flag[0, matched_det_id] = 1
+        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+    def one_to_many(local_sigma_table, local_tau_table,
+                    local_accumulative_recall, local_accumulative_precision,
+                    global_accumulative_recall, global_accumulative_precision,
+                    gt_flag, det_flag, idy):
+        hit_str_num = 0
+        for gt_id in range(num_gt):
+            # skip the following if the groundtruth was matched
+            if gt_flag[0, gt_id] > 0:
+                continue
+
+            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
+            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
+
+            if num_non_zero_in_sigma >= k:
+                ####search for all detections that overlap with this groundtruth
+                qualified_tau_candidates = np.where((local_tau_table[
+                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
+                num_qualified_tau_candidates = qualified_tau_candidates[
+                    0].shape[0]
+
+                if num_qualified_tau_candidates == 1:
+                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
+                            and
+                            (local_sigma_table[gt_id, qualified_tau_candidates] >=
+                             tr)):
+                        # becomes a one-to-one case
+                        global_accumulative_recall = global_accumulative_recall + 1.0
+                        global_accumulative_precision = global_accumulative_precision + 1.0
+                        local_accumulative_recall = local_accumulative_recall + 1.0
+                        local_accumulative_precision = local_accumulative_precision + 1.0
+
+                        gt_flag[0, gt_id] = 1
+                        det_flag[0, qualified_tau_candidates] = 1
+                        # recg start
+                        gt_str_cur = global_gt_str[idy][gt_id]
+                        pred_str_cur = global_pred_str[idy][
+                            qualified_tau_candidates[0].tolist()[0]]
+                        if pred_str_cur == gt_str_cur:
+                            hit_str_num += 1
+                        else:
+                            if pred_str_cur.lower() == gt_str_cur.lower():
+                                hit_str_num += 1
+                        # recg end
+                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
+                      >= tr):
+                    gt_flag[0, gt_id] = 1
+                    det_flag[0, qualified_tau_candidates] = 1
+                    # recg start
+                    gt_str_cur = global_gt_str[idy][gt_id]
+                    pred_str_cur = global_pred_str[idy][
+                        qualified_tau_candidates[0].tolist()[0]]
+                    if pred_str_cur == gt_str_cur:
+                        hit_str_num += 1
+                    else:
+                        if pred_str_cur.lower() == gt_str_cur.lower():
+                            hit_str_num += 1
+                    # recg end
+
+                    global_accumulative_recall = global_accumulative_recall + fsc_k
+                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
+
+                    local_accumulative_recall = local_accumulative_recall + fsc_k
+                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
+
+        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+    def many_to_one(local_sigma_table, local_tau_table,
+                    local_accumulative_recall, local_accumulative_precision,
+                    global_accumulative_recall, global_accumulative_precision,
+                    gt_flag, det_flag, idy):
+        hit_str_num = 0
+        for det_id in range(num_det):
+            # skip the following if the detection was matched
+            if det_flag[0, det_id] > 0:
+                continue
+
+            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
+            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
+
+            if num_non_zero_in_tau >= k:
+                ####search for all groundtruths that overlap with this detection
+                qualified_sigma_candidates = np.where((
+                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
+                num_qualified_sigma_candidates = qualified_sigma_candidates[
+                    0].shape[0]
+
+                if num_qualified_sigma_candidates == 1:
+                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
+                         tp) and
+                            (local_sigma_table[qualified_sigma_candidates, det_id]
+                             >= tr)):
+                        # becomes a one-to-one case
+                        global_accumulative_recall = global_accumulative_recall + 1.0
+                        global_accumulative_precision = global_accumulative_precision + 1.0
+                        local_accumulative_recall = local_accumulative_recall + 1.0
+                        local_accumulative_precision = local_accumulative_precision + 1.0
+
+                        gt_flag[0, qualified_sigma_candidates] = 1
+                        det_flag[0, det_id] = 1
+                        # recg start
+                        pred_str_cur = global_pred_str[idy][det_id]
+                        gt_len = len(qualified_sigma_candidates[0])
+                        for idx in range(gt_len):
+                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
+                                idx]
+                            if ele_gt_id not in global_gt_str[idy]:
+                                continue
+                            gt_str_cur = global_gt_str[idy][ele_gt_id]
+                            if pred_str_cur == gt_str_cur:
+                                hit_str_num += 1
+                                break
+                            else:
+                                if pred_str_cur.lower() == gt_str_cur.lower():
+                                    hit_str_num += 1
+                                    break
+                        # recg end
+                elif (np.sum(local_tau_table[qualified_sigma_candidates,
+                                             det_id]) >= tp):
+                    det_flag[0, det_id] = 1
+                    gt_flag[0, qualified_sigma_candidates] = 1
+                    # recg start
+                    pred_str_cur = global_pred_str[idy][det_id]
+                    gt_len = len(qualified_sigma_candidates[0])
+                    for idx in range(gt_len):
+                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
+                        if ele_gt_id not in global_gt_str[idy]:
+                            continue
+                        gt_str_cur = global_gt_str[idy][ele_gt_id]
+                        if pred_str_cur == gt_str_cur:
+                            hit_str_num += 1
+                            break
+                        else:
+                            if pred_str_cur.lower() == gt_str_cur.lower():
+                                hit_str_num += 1
+                                break
+                    # recg end
+
+                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
+                    global_accumulative_precision = global_accumulative_precision + fsc_k
+
+                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
+                    local_accumulative_precision = local_accumulative_precision + fsc_k
+        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+    for idx in range(len(global_sigma)):
+        local_sigma_table = np.array(global_sigma[idx])
+        local_tau_table = global_tau[idx]
+
+        num_gt = local_sigma_table.shape[0]
+        num_det = local_sigma_table.shape[1]
+
+        total_num_gt = total_num_gt + num_gt
+        total_num_det = total_num_det + num_det
+
+        local_accumulative_recall = 0
+        local_accumulative_precision = 0
+        gt_flag = np.zeros((1, num_gt))
+        det_flag = np.zeros((1, num_det))
+
+        #######first check for one-to-one case##########
+        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+        gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
+                                                    local_accumulative_recall, local_accumulative_precision,
+                                                    global_accumulative_recall, global_accumulative_precision,
+                                                    gt_flag, det_flag, idx)
+
+        hit_str_count += hit_str_num
+        #######then check for one-to-many case##########
+        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+        gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
+                                                     local_accumulative_recall, local_accumulative_precision,
+                                                     global_accumulative_recall, global_accumulative_precision,
+                                                     gt_flag, det_flag, idx)
+        hit_str_count += hit_str_num
+        #######then check for many-to-one case##########
+        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+        gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
+                                                     local_accumulative_recall, local_accumulative_precision,
+                                                     global_accumulative_recall, global_accumulative_precision,
+                                                     gt_flag, det_flag, idx)
+        hit_str_count += hit_str_num
+
+    try:
+        recall = global_accumulative_recall / total_num_gt
+    except ZeroDivisionError:
+        recall = 0
+
+    try:
+        precision = global_accumulative_precision / total_num_det
+    except ZeroDivisionError:
+        precision = 0
+
+    try:
+        f_score = 2 * precision * recall / (precision + recall)
+    except ZeroDivisionError:
+        f_score = 0
+
+    try:
+        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
+    except ZeroDivisionError:
+        seqerr = 1
+
+    try:
+        recall_e2e = float(hit_str_count) / total_num_gt
+    except ZeroDivisionError:
+        recall_e2e = 0
+
+    try:
+        precision_e2e = float(hit_str_count) / total_num_det
+    except ZeroDivisionError:
+        precision_e2e = 0
+
+    try:
+        f_score_e2e = 2 * precision_e2e * recall_e2e / (
+            precision_e2e + recall_e2e)
+    except ZeroDivisionError:
+        f_score_e2e = 0
+
+    final = {
+        'total_num_gt': total_num_gt,
+        'total_num_det': total_num_det,
+        'global_accumulative_recall': global_accumulative_recall,
+        'hit_str_count': hit_str_count,
+        'recall': recall,
+        'precision': precision,
+        'f_score': f_score,
+        'seqerr': seqerr,
+        'recall_e2e': recall_e2e,
+        'precision_e2e': precision_e2e,
+        'f_score_e2e': f_score_e2e
+    }
+    return final
+
+a = [
+    1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620,
+    1659, 620, 1654, 681, 1631, 680, 1618, 681, 1606, 681, 1594, 681, 1584, 682,
+    1573, 685, 1542, 694
+]
+gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
+pred_dict = [{
+    'points': np.array(a),
+    'text': 'ccc'
+}, {
+    'points': np.array(a),
+    'text': 'ccf'
+}]
+result = []
+result.append(get_socre(gt_dict, pred_dict))
+a = combine_results(result)
+print(a)
diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint.py
new file mode 100644
index 00000000..96ebf02e
--- /dev/null
+++ b/ppocr/utils/e2e_utils/extract_textpoint.py
@@ -0,0 +1,532 @@
+"""Contains various CTC decoders."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function

+import os
+import cv2
+import time
+import math

+import numpy as np
+from itertools import groupby
+from ppocr.utils.e2e_utils.ski_thin import thin
+
+
+def softmax(logits):
+    """
+    logits: N x d
+    """
+    max_value = np.max(logits, axis=1, keepdims=True)
+    exp = np.exp(logits - max_value)
+    exp_sum = np.sum(exp, axis=1, keepdims=True)
+    dist = exp / exp_sum
+    return dist
+
+
+def get_keep_pos_idxs(labels, remove_blank=None):
+    """
+    Remove duplicates and keep the position index of each kept item.
+    remove_blank should be None (keep blanks) or the blank id (e.g. 95).
+    """
+    duplicate_len_list = []
+    keep_pos_idx_list = []
+    keep_char_idx_list = []
+    for k, v_ in groupby(labels):
+        current_len = len(list(v_))
+        if k != remove_blank:
+            current_idx = int(sum(duplicate_len_list) + current_len // 2)
+            keep_pos_idx_list.append(current_idx)
+            keep_char_idx_list.append(k)
+        duplicate_len_list.append(current_len)
+    return keep_char_idx_list, keep_pos_idx_list
+
+
+def remove_blank(labels, blank=0):
+    new_labels = [x for x in labels if x != blank]
+    return new_labels
+
+
+def insert_blank(labels, blank=0):
+    new_labels = [blank]
+    for l in labels:
+        new_labels += [l, blank]
+    return new_labels
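+
+
+# Illustrative trace of the helpers above (assuming blank id 0): collapsing
+# [0, 1, 1, 2, 0, 0] with get_keep_pos_idxs(..., remove_blank=0) gives
+# ([1, 2], [2, 3]) -- one label per run of repeats, keeping the index of the
+# middle element of each run -- and remove_blank() then drops any remaining
+# blanks, which together implement the standard CTC best-path collapse.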
+def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
+    """
+    CTC greedy (best path) decoder.
+    """
+    raw_str = np.argmax(np.array(probs_seq), axis=1)
+    remove_blank_in_pos = None if keep_blank_in_idxs else blank
+    dedup_str, keep_idx_list = get_keep_pos_idxs(
+        raw_str, remove_blank=remove_blank_in_pos)
+    dst_str = remove_blank(dedup_str, blank=blank)
+    return dst_str, keep_idx_list
+
+
+def instance_ctc_greedy_decoder(gather_info,
+                                logits_map,
+                                keep_blank_in_idxs=True):
+    """
+    gather_info: [[y, x], [y, x] ...]
+    logits_map: H x W x (n_chars + 1)
+    """
+    _, _, C = logits_map.shape
+    ys, xs = zip(*gather_info)
+    logits_seq = logits_map[list(ys), list(xs)]  # n x 96
+    probs_seq = softmax(logits_seq)
+    dst_str, keep_idx_list = ctc_greedy_decoder(
+        probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
+    keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
+    return dst_str, keep_gather_list
+
+
+def ctc_decoder_for_image(gather_info_list, logits_map,
+                          keep_blank_in_idxs=True):
+    """
+    CTC decoder for every text instance of one image.
+    """
+    decoder_results = []
+    for gather_info in gather_info_list:
+        res = instance_ctc_greedy_decoder(
+            gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
+        decoder_results.append(res)
+    return decoder_results
+
+
+def sort_with_direction(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[y, x], [y, x], [y, x] ...]
+    """
+
+    def sort_part_with_direction(pos_list, point_direction):
+        pos_list = np.array(pos_list).reshape(-1, 2)
+        point_direction = np.array(point_direction).reshape(-1, 2)
+        average_direction = np.mean(point_direction, axis=0, keepdims=True)
+        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+        sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
+        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+        return sorted_list, sorted_direction
+
+    pos_list = np.array(pos_list).reshape(-1, 2)
+    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
+    point_direction = point_direction[:, ::-1]  # x, y -> y, x
+    sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+                                                              point_direction)
+
+    point_num = len(sorted_point)
+    if point_num >= 16:
+        middle_num = point_num // 2
+        first_part_point = sorted_point[:middle_num]
+        first_point_direction = sorted_direction[:middle_num]
+        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+            first_part_point, first_point_direction)
+
+        last_part_point = sorted_point[middle_num:]
+        last_point_direction = sorted_direction[middle_num:]
+        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+            last_part_point, last_point_direction)
+        sorted_point = sorted_first_part_point + sorted_last_part_point
+        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+    return sorted_point, np.array(sorted_direction)
+
+
+def add_id(pos_list, image_id=0):
+    """
+    Add id for gather feature, for inference.
+    """
+    new_list = []
+    for item in pos_list:
+        new_list.append((image_id, item[0], item[1]))
+    return new_list
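+
+
+# Example of the projection sort above (illustrative): with an average
+# direction of (0, 1) -- pointing along +x in (y, x) order -- the points
+# [[5, 9], [5, 2], [5, 4]] have projections [9, 2, 4] and come back ordered
+# as [[5, 2], [5, 4], [5, 9]], i.e. left to right along the text line.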
+def sort_and_expand_with_direction(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[y, x], [y, x], [y, x] ...]
+    """
+    h, w, _ = f_direction.shape
+    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+    # expand along the average direction at both ends
+    point_num = len(sorted_list)
+    sub_direction_len = max(point_num // 3, 2)
+    left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+    left_average_len = np.linalg.norm(left_average_direction)
+    left_start = np.array(sorted_list[0])
+    left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+    right_average_len = np.linalg.norm(right_average_direction)
+    right_step = right_average_direction / (right_average_len + 1e-6)
+    right_start = np.array(sorted_list[-1])
+
+    append_num = max(
+        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+    left_list = []
+    right_list = []
+    for i in range(append_num):
+        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ly < h and lx < w and (ly, lx) not in left_list:
+            left_list.append((ly, lx))
+        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ry < h and rx < w and (ry, rx) not in right_list:
+            right_list.append((ry, rx))
+
+    all_list = left_list[::-1] + sorted_list + right_list
+    return all_list
+
+
+def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
+    """
+    f_direction: h x w x 2
+    pos_list: [[y, x], [y, x], [y, x] ...]
+    binary_tcl_map: h x w
+    """
+    h, w, _ = f_direction.shape
+    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+    # expand along the average direction at both ends, but stop once the
+    # binary TCL map is left
+    point_num = len(sorted_list)
+    sub_direction_len = max(point_num // 3, 2)
+    left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+    left_average_len = np.linalg.norm(left_average_direction)
+    left_start = np.array(sorted_list[0])
+    left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+    right_average_len = np.linalg.norm(right_average_direction)
+    right_step = right_average_direction / (right_average_len + 1e-6)
+    right_start = np.array(sorted_list[-1])
+
+    append_num = max(
+        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+    max_append_num = 2 * append_num
+
+    left_list = []
+    right_list = []
+    for i in range(max_append_num):
+        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ly < h and lx < w and (ly, lx) not in left_list:
+            if binary_tcl_map[ly, lx] > 0.5:
+                left_list.append((ly, lx))
+            else:
+                break
+
+    for i in range(max_append_num):
+        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ry < h and rx < w and (ry, rx) not in right_list:
+            if binary_tcl_map[ry, rx] > 0.5:
+                right_list.append((ry, rx))
+            else:
+                break
+
+    all_list = left_list[::-1] + sorted_list + right_list
+    return all_list
+def generate_pivot_list_curved(p_score,
+                               p_char_maps,
+                               f_direction,
+                               score_thresh=0.5,
+                               is_expand=True,
+                               is_backbone=False,
+                               image_id=0):
+    """
+    return center point and end point of TCL instance; filter with the char maps;
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map = (p_score > score_thresh) * 1.0
+    skeleton_map = thin(p_tcl_map)
+    instance_count, instance_label_map = cv2.connectedComponents(
+        skeleton_map.astype(np.uint8), connectivity=8)
+
+    # get TCL Instance
+    all_pos_yxs = []
+    center_pos_yxs = []
+    end_points_yxs = []
+    instance_center_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 3:
+                continue
+
+            if is_expand:
+                pos_list_sorted = sort_and_expand_with_direction_v2(
+                    pos_list, f_direction, p_tcl_map)
+            else:
+                pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
+            all_pos_yxs.append(pos_list_sorted)
+
+    # use decoder to filter background points.
+    p_char_maps = p_char_maps.transpose([1, 2, 0])
+    decode_res = ctc_decoder_for_image(
+        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+    for decoded_str, keep_yxs_list in decode_res:
+        if is_backbone:
+            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+            instance_center_pos_yxs.append(keep_yxs_list_with_id)
+        else:
+            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+            center_pos_yxs.extend(keep_yxs_list)
+
+    if is_backbone:
+        return instance_center_pos_yxs
+    else:
+        return center_pos_yxs, end_points_yxs
+ p_char_maps = p_char_maps.transpose([1, 2, 0]) + decode_res = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True) + for decoded_str, keep_yxs_list in decode_res: + if is_backbone: + keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id) + instance_center_pos_yxs.append(keep_yxs_list_with_id) + else: + end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1])) + center_pos_yxs.extend(keep_yxs_list) + + if is_backbone: + return instance_center_pos_yxs + else: + return center_pos_yxs, end_points_yxs + + +def generate_pivot_list(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + is_curved=True, + image_id=0): + """ + Warp all the function together. + """ + if is_curved: + return generate_pivot_list_curved( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_expand=True, + is_backbone=is_backbone, + image_id=image_id) + else: + return generate_pivot_list_horizontal( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_backbone=is_backbone, + image_id=image_id) + + +# for refine module +def extract_main_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + pos_list = np.array(pos_list) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + average_direction = average_direction / ( + np.linalg.norm(average_direction) + 1e-6) + return average_direction + + +def sort_by_direction_with_image_id_deprecated(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...] + """ + pos_list_full = np.array(pos_list).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + return sorted_list + + +def sort_by_direction_with_image_id(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] 
+
+
+def sort_by_direction_with_image_id(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+    """
+
+    def sort_part_with_direction(pos_list_full, point_direction):
+        pos_list_full = np.array(pos_list_full).reshape(-1, 3)
+        pos_list = pos_list_full[:, 1:]
+        point_direction = np.array(point_direction).reshape(-1, 2)
+        average_direction = np.mean(point_direction, axis=0, keepdims=True)
+        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+        sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+        return sorted_list, sorted_direction
+
+    pos_list = np.array(pos_list).reshape(-1, 3)
+    point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]]  # x, y
+    point_direction = point_direction[:, ::-1]  # x, y -> y, x
+    sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+                                                              point_direction)
+
+    # for long instances, re-sort each half with its own local direction
+    point_num = len(sorted_point)
+    if point_num >= 16:
+        middle_num = point_num // 2
+        first_part_point = sorted_point[:middle_num]
+        first_point_direction = sorted_direction[:middle_num]
+        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+            first_part_point, first_point_direction)
+
+        last_part_point = sorted_point[middle_num:]
+        last_point_direction = sorted_direction[middle_num:]
+        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+            last_part_point, last_point_direction)
+        sorted_point = sorted_first_part_point + sorted_last_part_point
+        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+    return sorted_point
+
+
+def generate_pivot_list_tt_inference(p_score,
+                                     p_char_maps,
+                                     f_direction,
+                                     score_thresh=0.5,
+                                     is_backbone=False,
+                                     is_curved=True,
+                                     image_id=0):
+    """
+    Return the sorted, id-tagged center point list of each TCL instance.
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map = (p_score > score_thresh) * 1.0
+    skeleton_map = thin(p_tcl_map)
+    instance_count, instance_label_map = cv2.connectedComponents(
+        skeleton_map.astype(np.uint8), connectivity=8)
+
+    # get TCL Instance
+    all_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+            # FIXME: eliminate outliers
+            if len(pos_list) < 3:
+                continue
+            pos_list_sorted = sort_and_expand_with_direction_v2(
+                pos_list, f_direction, p_tcl_map)
+            # pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
+            pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
+            all_pos_yxs.append(pos_list_sorted_with_id)
+    return all_pos_yxs
+
+
+if __name__ == '__main__':
+    np.random.seed(0)
+    import time
+
+    logits_map = np.random.random([10, 20, 33])
+    # a list of [x, y]
+    instance_gather_info_1 = [(2, 3), (2, 4), (3, 5)]
+    instance_gather_info_2 = [(15, 6), (15, 7), (18, 8)]
+    instance_gather_info_3 = [(8, 8), (8, 8), (8, 8)]
+    gather_info_list = [
+        instance_gather_info_1, instance_gather_info_2, instance_gather_info_3
+    ]
+
+    time0 = time.time()
+    res = ctc_decoder_for_image(
+        gather_info_list, logits_map, keep_blank_in_idxs=True)
+    print(res)
+    print('cost {}'.format(time.time() - time0))
+    print('--' * 20)
diff --git a/ppocr/utils/e2e_utils/ski_thin.py b/ppocr/utils/e2e_utils/ski_thin.py
new file mode 100644
index 00000000..dba2afdd
--- /dev/null
+++ b/ppocr/utils/e2e_utils/ski_thin.py
@@ -0,0 +1,126 @@
+"""
+Algorithms for computing the skeleton of a binary image
+"""
+
+import numpy as np
+from scipy import ndimage as ndi
+
+G123_LUT = np.array(
+    [
+        0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, + 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, + 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, + 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0 + ], + dtype=np.bool) + +G123P_LUT = np.array( + [ + 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + ], + dtype=np.bool) + + +def thin(image, max_iter=None): + """ + Perform morphological thinning of a binary image. + Parameters + ---------- + image : binary (M, N) ndarray + The image to be thinned. + max_iter : int, number of iterations, optional + Regardless of the value of this parameter, the thinned image + is returned immediately if an iteration produces no change. + If this parameter is specified it thus sets an upper bound on + the number of iterations performed. + Returns + ------- + out : ndarray of bool + Thinned image. + See also + -------- + skeletonize, medial_axis + Notes + ----- + This algorithm [1]_ works by making multiple passes over the image, + removing pixels matching a set of criteria designed to thin + connected regions while preserving eight-connected components and + 2 x 2 squares [2]_. In each of the two sub-iterations the algorithm + correlates the intermediate skeleton image with a neighborhood mask, + then looks up each neighborhood in a lookup table indicating whether + the central pixel should be deleted in that sub-iteration. + References + ---------- + .. [1] Z. Guo and R. W. Hall, "Parallel thinning with + two-subiteration algorithms," Comm. ACM, vol. 32, no. 3, + pp. 359-373, 1989. :DOI:`10.1145/62065.62074` + .. [2] Lam, L., Seong-Whan Lee, and Ching Y. Suen, "Thinning + Methodologies-A Comprehensive Survey," IEEE Transactions on + Pattern Analysis and Machine Intelligence, Vol 14, No. 9, + p. 879, 1992. 
:DOI:`10.1109/34.161346`
+    Examples
+    --------
+    >>> square = np.zeros((7, 7), dtype=np.uint8)
+    >>> square[1:-1, 2:-2] = 1
+    >>> square[0, 1] = 1
+    >>> square
+    array([[0, 1, 0, 0, 0, 0, 0],
+           [0, 0, 1, 1, 1, 0, 0],
+           [0, 0, 1, 1, 1, 0, 0],
+           [0, 0, 1, 1, 1, 0, 0],
+           [0, 0, 1, 1, 1, 0, 0],
+           [0, 0, 1, 1, 1, 0, 0],
+           [0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
+    >>> skel = thin(square)
+    >>> skel.astype(np.uint8)
+    array([[0, 1, 0, 0, 0, 0, 0],
+           [0, 0, 1, 0, 0, 0, 0],
+           [0, 0, 0, 1, 0, 0, 0],
+           [0, 0, 0, 1, 0, 0, 0],
+           [0, 0, 0, 1, 0, 0, 0],
+           [0, 0, 0, 0, 0, 0, 0],
+           [0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
+    """
+    # convert image to uint8 with values in {0, 1}
+    skel = np.asanyarray(image, dtype=bool).astype(np.uint8)
+
+    # neighborhood mask
+    mask = np.array([[8, 4, 2], [16, 0, 1], [32, 64, 128]], dtype=np.uint8)
+
+    # iterate until convergence, up to the iteration limit
+    max_iter = max_iter or np.inf
+    n_iter = 0
+    n_pts_old, n_pts_new = np.inf, np.sum(skel)
+    while n_pts_old != n_pts_new and n_iter < max_iter:
+        n_pts_old = n_pts_new
+
+        # perform the two "subiterations" described in the paper
+        for lut in [G123_LUT, G123P_LUT]:
+            # correlate image with neighborhood mask
+            N = ndi.correlate(skel, mask, mode='constant')
+            # take deletion decision from this subiteration's LUT
+            D = np.take(lut, N)
+            # perform deletion
+            skel[D] = 0
+
+        n_pts_new = np.sum(skel)  # count points after thinning
+        n_iter += 1
+
+    return skel.astype(np.bool)
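+
+
+# Editor's note: in thin() above, ndi.correlate with the mask gives each of
+# the eight neighbors a distinct power of two, so the correlation result at a
+# pixel is a code in [0, 255] that uniquely identifies which neighbors are
+# set; the G123 lookup tables are indexed by this code to decide whether the
+# center pixel is deleted in that subiteration. Neighbor weights, as laid out
+# in the mask:
+#
+#      8   4    2
+#     16  (p)   1
+#     32  64  128
+#
+# For example, a pixel whose east (1) and south (64) neighbors are set maps
+# to LUT index 1 + 64 = 65.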
diff --git a/ppocr/utils/e2e_utils/visual.py b/ppocr/utils/e2e_utils/visual.py
new file mode 100644
index 00000000..4c96e5c7
--- /dev/null
+++ b/ppocr/utils/e2e_utils/visual.py
@@ -0,0 +1,343 @@
+import os
+import numpy as np
+import cv2
+import time
+
+
+def visualize_e2e_result(im_fn, poly_list, seq_strs, src_im):
+    """
+    Draw ground-truth polygons and detected polygons with their recognized
+    strings on a copy of the source image.
+    """
+    result_path = './out'
+    im_basename = os.path.basename(im_fn)
+    im_prefix = im_basename[:im_basename.rfind('.')]
+    vis_det_img = src_im.copy()
+    valid_set = 'partvgg'
+    # NOTE: developer-local debug path left by the author
+    gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train"
+    text_path = os.path.join(gt_dir, im_prefix + '.txt')
+    with open(text_path, 'r') as fid:
+        lines = [line.strip() for line in fid.readlines()]
+    for line in lines:
+        if valid_set == 'partvgg':
+            tokens = line.strip().split('\t')[0].split(',')
+            # tokens = line.strip().split(',')
+            coords = tokens[:]
+            coords = list(map(float, coords))
+            gt_poly = np.array(coords).reshape(1, 4, 2)
+        elif valid_set == 'totaltext':
+            tokens = line.strip().split('\t')[0].split(',')
+            coords = tokens[:]
+            coords_len = len(coords) // 2
+            coords = list(map(float, coords))
+            gt_poly = np.array(coords).reshape(1, coords_len, 2)
+        cv2.polylines(
+            vis_det_img,
+            np.array(gt_poly).astype(np.int32),
+            isClosed=True,
+            color=(255, 0, 0),
+            thickness=2)
+
+    for detected_poly, recognized_str in zip(poly_list, seq_strs):
+        cv2.polylines(
+            vis_det_img,
+            np.array(detected_poly[np.newaxis, ...]).astype(np.int32),
+            isClosed=True,
+            color=(0, 0, 255),
+            thickness=2)
+        cv2.putText(
+            vis_det_img,
+            recognized_str,
+            org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])),
+            fontFace=cv2.FONT_HERSHEY_COMPLEX,
+            fontScale=0.7,
+            color=(0, 255, 0),
+            thickness=1)
+
+    if not os.path.exists(result_path):
+        os.makedirs(result_path)
+    cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix),
+                vis_det_img)
+
+
+def visualization_output(src_image,
+                         f_tcl,
+                         f_chars,
+                         output_dir,
+                         image_prefix=None):
+    """
+    Overlay the TCL map and the character maps on the de-normalized image.
+    """
+    # restore BGR image, CHW -> HWC
+    im_mean = [0.485, 0.456, 0.406]
+    im_std = [0.229, 0.224, 0.225]
+
+    im_mean = np.array(im_mean).reshape((3, 1, 1))
+    im_std = np.array(im_std).reshape((3, 1, 1))
+    src_image *= im_std
+    src_image += im_mean
+    src_image = src_image.transpose([1, 2, 0])
+    src_image = src_image[:, :, ::-1] * 255  # reverse channel order
+    H, W, _ = src_image.shape
+
+    file_prefix = image_prefix if image_prefix is not None else str(
+        int(time.time() * 1000))
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    # visualize f_tcl
+    tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg')
+    vis_tcl_img = src_image.copy()
+    f_tcl_resized = cv2.resize(f_tcl, dsize=(W, H))
+    vis_tcl_img[:, :, 1] = f_tcl_resized * 255
+    cv2.imwrite(tcl_file_name, vis_tcl_img)
+
+    # visualize char maps
+    vis_char_img = src_image.copy()
+    # CHW -> HWC
+    char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg')
+    # index 95 is assumed to be the blank/background class
+    f_chars = np.argmax(f_chars, axis=2)[:, :, np.newaxis].astype('float32')
+    f_chars[f_chars < 95] = 1.0
+    f_chars[f_chars == 95] = 0.0
+    f_chars_resized = cv2.resize(f_chars, dsize=(W, H))
+    vis_char_img[:, :, 1] = f_chars_resized * 255
+    cv2.imwrite(char_file_name, vis_char_img)
+
+
+def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir,
+                           result_path):
+    """
+    Draw ground-truth polygons, center points and border point pairs.
+    """
+    im_basename = os.path.basename(im_fn)
+    im_prefix = im_basename[:im_basename.rfind('.')]
+    vis_det_img = src_im.copy()
+
+    # draw gt bbox on the image.
+    text_path = os.path.join(gt_dir, im_prefix + '.txt')
+    with open(text_path, 'r') as fid:
+        lines = [line.strip() for line in fid.readlines()]
+    for line in lines:
+        tokens = line.strip().split('\t')
+        coords = tokens[0].split(',')
+        coords_len = len(coords)
+        coords = list(map(float, coords))
+        gt_poly = np.array(coords).reshape(1, coords_len // 2, 2)
+        cv2.polylines(
+            vis_det_img,
+            np.array(gt_poly).astype(np.int32),
+            isClosed=True,
+            color=(255, 255, 255),
+            thickness=1)
+
+    for point, point_pair in zip(point_list, point_pair_list):
+        cv2.line(
+            vis_det_img,
+            tuple(point_pair[0]),
+            tuple(point_pair[1]), (0, 255, 255),
+            thickness=1)
+        cv2.circle(vis_det_img, tuple(point), 2, (0, 0, 255))
+        cv2.circle(vis_det_img, tuple(point_pair[0]), 2, (255, 0, 0))
+        cv2.circle(vis_det_img, tuple(point_pair[1]), 2, (0, 255, 0))
+
+    if not os.path.exists(result_path):
+        os.makedirs(result_path)
+    cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix),
+                vis_det_img)
+
+
+def resize_image(im, max_side_len=512):
+    """
+    Resize the image so that its sides are multiples of max_stride, which the
+    network requires.
+    :param im: the image to be resized
+    :param max_side_len: limit of max image size to avoid out of memory in gpu
+    :return: the resized image and the resize ratio
+    """
+    h, w, _ = im.shape
+
+    resize_w = w
+    resize_h = h
+
+    # Fix the longer side
+    if resize_h > resize_w:
+        ratio = float(max_side_len) / resize_h
+    else:
+        ratio = float(max_side_len) / resize_w
+
+    resize_h = int(resize_h * ratio)
+    resize_w = int(resize_w * ratio)
+
+    max_stride = 128
+    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+    im = cv2.resize(im, (int(resize_w), int(resize_h)))
+    ratio_h = resize_h / float(h)
+    ratio_w = resize_w / float(w)
+
+    return im, (ratio_h, ratio_w)
+
+
+def resize_image_min(im, max_side_len=512):
+    """
+    Resize the image so that its shorter side is scaled to max_side_len, then
+    round both sides up to multiples of max_stride.
+    """
+    print('--> Using resize_image_min')
+    h, w, _ = im.shape
+
+    resize_w = w
+    resize_h = h
+
+    # Fix the shorter side
+    if resize_h < resize_w:
+        ratio = float(max_side_len) / resize_h
+    else:
+        ratio = float(max_side_len) / resize_w
+
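+    # Editor's note (worked example of the rounding below): for a 720 x 1280
+    # image with max_side_len=512, ratio = 512 / 720 ~ 0.711, giving
+    # 512 x 910 before rounding; ceiling both sides to multiples of 128 then
+    # yields 512 x 1024.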
+    resize_h = int(resize_h * ratio)
+    resize_w = int(resize_w * ratio)
+
+    max_stride = 128
+    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+    im = cv2.resize(im, (int(resize_w), int(resize_h)))
+    ratio_h = resize_h / float(h)
+    ratio_w = resize_w / float(w)
+    return im, (ratio_h, ratio_w)
+
+
+def resize_image_for_totaltext(im, max_side_len=512):
+    """
+    Resize with a fixed 1.25 upscale unless that would exceed max_side_len.
+    """
+    h, w, _ = im.shape
+
+    resize_w = w
+    resize_h = h
+    ratio = 1.25
+    if h * ratio > max_side_len:
+        ratio = float(max_side_len) / resize_h
+    # Fix the longer side
+    # if resize_h > resize_w:
+    #     ratio = float(max_side_len) / resize_h
+    # else:
+    #     ratio = float(max_side_len) / resize_w
+    resize_h = int(resize_h * ratio)
+    resize_w = int(resize_w * ratio)
+
+    max_stride = 128
+    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+    im = cv2.resize(im, (int(resize_w), int(resize_h)))
+    ratio_h = resize_h / float(h)
+    ratio_w = resize_w / float(w)
+    return im, (ratio_h, ratio_w)
+
+
+def point_pair2poly(point_pair_list):
+    """
+    Transfer vertical point pairs into polygon points in clockwise order.
+    """
+    pair_length_list = []
+    for point_pair in point_pair_list:
+        pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
+        pair_length_list.append(pair_length)
+    pair_length_list = np.array(pair_length_list)
+    pair_info = (pair_length_list.max(), pair_length_list.min(),
+                 pair_length_list.mean())
+
+    # construct poly
+    point_num = len(point_pair_list) * 2
+    point_list = [0] * point_num
+    for idx, point_pair in enumerate(point_pair_list):
+        point_list[idx] = point_pair[0]
+        point_list[point_num - 1 - idx] = point_pair[1]
+    return np.array(point_list).reshape(-1, 2), pair_info
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+    """
+    Shrink a quad along its width by interpolating between the left and right
+    edges at the given width ratios.
+    """
+    ratio_pair = np.array(
+        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+    p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+    p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+    return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
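+
+
+# Editor's note: a small, hypothetical example of point_pair2poly above. Each
+# point pair holds one upper and one lower boundary point (2D numpy arrays);
+# the pairs are unrolled so the upper points run left to right followed by
+# the lower points right to left:
+#
+#     pairs = [np.array([[0, 0], [0, 10]]), np.array([[10, 0], [10, 10]])]
+#     poly, pair_info = point_pair2poly(pairs)
+#     # poly -> [[0, 0], [10, 0], [10, 10], [0, 10]], one closed quad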
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+    """
+    Expand poly along its width by pushing out the left and right end quads.
+    """
+    point_num = poly.shape[0]
+    left_quad = np.array(
+        [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+    left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
+                 (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+    left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+    right_quad = np.array(
+        [
+            poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+            poly[point_num // 2], poly[point_num // 2 + 1]
+        ],
+        dtype=np.float32)
+    right_ratio = 1.0 + \
+                  shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+                  (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+    right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+    poly[0] = left_quad_expand[0]
+    poly[-1] = left_quad_expand[-1]
+    poly[point_num // 2 - 1] = right_quad_expand[1]
+    poly[point_num // 2] = right_quad_expand[2]
+    return poly
+
+
+def norm2(x, axis=None):
+    # test against None explicitly so axis=0 is handled correctly
+    if axis is not None:
+        return np.sqrt(np.sum(x**2, axis=axis))
+    return np.sqrt(np.sum(x**2))
+
+
+def cos(p1, p2):
+    return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
+
+
+def generate_direction_info(image_fn,
+                            H,
+                            W,
+                            ratio_h,
+                            ratio_w,
+                            max_length=640,
+                            out_scale=4,
+                            gt_dir=None):
+    """
+    Rasterize a per-pixel direction map (plus average char distance) from the
+    ground-truth polygons.
+    """
+    im_basename = os.path.basename(image_fn)
+    im_prefix = im_basename[:im_basename.rfind('.')]
+    instance_direction_map = np.zeros(shape=[H // out_scale, W // out_scale, 3])
+
+    if gt_dir is None:
+        # NOTE: developer-local default path left by the author
+        gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly'
+
+    # get gt label map
+    text_path = os.path.join(gt_dir, im_prefix + '.txt')
+    with open(text_path, 'r') as fid:
+        lines = [line.strip() for line in fid.readlines()]
+    for label_idx, line in enumerate(lines, start=1):
+        coords, txt = line.strip().split('\t')
+        if txt == '###':
+            continue
+        tokens = coords.strip().split(',')
+        coords = list(map(float, tokens))
+        poly = np.array(coords).reshape(4, 2) * np.array(
+            [ratio_w, ratio_h]).reshape(1, 2) / out_scale
+        mid_idx = poly.shape[0] // 2
+        direct_vector = (
+            (poly[mid_idx] + poly[mid_idx - 1]) - (poly[0] + poly[-1])) / 2.0
+
+        direct_vector /= len(txt)
+        # l2_distance = norm2(direct_vector)
+        # avg_char_distance = l2_distance / len(txt)
+        avg_char_distance = 1.0
+
+        direct_label = (direct_vector[0], direct_vector[1], avg_char_distance)
+        cv2.fillPoly(instance_direction_map,
+                     poly.round().astype(np.int32)[np.newaxis, :, :],
+                     direct_label)
+    instance_direction_map = instance_direction_map.transpose([2, 0, 1])
+    return instance_direction_map[:2, ...]
diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py
new file mode 100755
index 00000000..9dcde8a9
--- /dev/null
+++ b/tools/infer_e2e.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
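+
+# Editor's note: this script is a standalone end-to-end inference entry
+# point: it builds the PGNet model from the e2e config, runs the post-process
+# step to obtain polygons and transcriptions, writes them as JSON lines to
+# Global.save_res_path, and renders each result via draw_e2e_res. A typical
+# invocation (image path illustrative):
+#
+#     python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml \
+#         -o Global.infer_img=./doc/imgs_en/img_10.jpg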
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import json
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import init_model
+from ppocr.utils.utility import get_image_file_list
+import tools.program as program
+
+
+def draw_e2e_res(dt_boxes, strs, config, img, img_name):
+    if len(dt_boxes) > 0:
+        src_im = img
+        for box, txt in zip(dt_boxes, strs):
+            box = box.astype(np.int32).reshape((-1, 1, 2))
+            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+            cv2.putText(src_im, txt, org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
+                        fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.7,
+                        color=(0, 255, 0), thickness=1)
+        save_det_path = os.path.dirname(config['Global'][
+            'save_res_path']) + "/e2e_results/"
+        if not os.path.exists(save_det_path):
+            os.makedirs(save_det_path)
+        save_path = os.path.join(save_det_path, os.path.basename(img_name))
+        cv2.imwrite(save_path, src_im)
+        logger.info("The e2e image is saved in {}".format(save_path))
+
+
+def main():
+    global_config = config['Global']
+
+    # build model
+    model = build_model(config['Architecture'])
+
+    init_model(config, model, logger)
+
+    # build post process
+    post_process_class = build_post_process(config['PostProcess'])
+
+    # create data ops
+    transforms = []
+    for op in config['Eval']['dataset']['transforms']:
+        op_name = list(op)[0]
+        if 'Label' in op_name:
+            continue
+        elif op_name == 'KeepKeys':
+            op[op_name]['keep_keys'] = ['image', 'shape']
+        transforms.append(op)
+
+    ops = create_operators(transforms, global_config)
+
+    save_res_path = config['Global']['save_res_path']
+    if not os.path.exists(os.path.dirname(save_res_path)):
+        os.makedirs(os.path.dirname(save_res_path))
+
+    model.eval()
+    with open(save_res_path, "wb") as fout:
+        for file in get_image_file_list(config['Global']['infer_img']):
+            logger.info("infer_img: {}".format(file))
+            with open(file, 'rb') as f:
+                img = f.read()
+            data = {'image': img}
+            batch = transform(data, ops)
+            images = np.expand_dims(batch[0], axis=0)
+            shape_list = np.expand_dims(batch[1], axis=0)
+            images = paddle.to_tensor(images)
+            preds = model(images)
+            post_result = post_process_class(preds, shape_list)
+            points, strs = post_result['points'], post_result['strs']
+            # write result
+            dt_boxes_json = []
+            for poly, txt in zip(points, strs):
+                tmp_json = {"transcription": txt}
+                tmp_json['points'] = poly.tolist()
+                dt_boxes_json.append(tmp_json)
+            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
+            fout.write(otstr.encode())
+            src_img = cv2.imread(file)
+            draw_e2e_res(points, strs, config, src_img, file)
+    logger.info("success!")
+
+
+if __name__ == '__main__':
+    config, device, logger, vdl_writer = program.preprocess()
+    main()
diff --git a/tools/program.py b/tools/program.py
index ae649176..778af8ec 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -374,7 +374,8 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
-        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
+        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
+        'CLS', 'PG'
     ]
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
-- 
GitLab