From bb49e1a53f34490f884a4e287a4c9731c7d61828 Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Mon, 8 Mar 2021 15:11:57 +0800 Subject: [PATCH] ADD PGNet_v2 --- ppocr/data/imaug/label_ops.py | 3 +- ppocr/metrics/e2e_metric.py | 17 - ppocr/modeling/necks/pg_fpn.py | 280 ++++--- ppocr/utils/e2e_metric/Deteval.py | 30 +- ppocr/utils/e2e_metric/polygon_fast.py | 14 +- ppocr/utils/e2e_metric/tttt.py | 881 --------------------- ppocr/utils/e2e_utils/extract_textpoint.py | 13 + ppocr/utils/e2e_utils/ski_thin.py | 16 +- ppocr/utils/e2e_utils/visual.py | 198 +---- tools/program.py | 1 - 10 files changed, 227 insertions(+), 1226 deletions(-) delete mode 100644 ppocr/utils/e2e_metric/tttt.py diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 4cae2337..3ae22b40 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -37,6 +37,7 @@ class ClsLabelEncode(object): class E2ELabelEncode(object): def __init__(self, label_list, **kwargs): self.label_list = label_list + self.max_len = 50 def __call__(self, data): text_label_index_list, temp_text = [], [] @@ -47,7 +48,7 @@ class E2ELabelEncode(object): for c_ in text: if c_ in self.label_list: temp_text.append(self.label_list.index(c_)) - temp_text = temp_text + [36] * (50 - len(temp_text)) + temp_text = temp_text + [36] * (self.max_len - len(temp_text)) text_label_index_list.append(temp_text) data['strs'] = np.array(text_label_index_list) return data diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index 6901187a..45248b91 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -32,16 +32,6 @@ class E2EMetric(object): self.reset() def __call__(self, preds, batch, **kwargs): - ''' - batch: a list produced by dataloaders. - image: np.ndarray of shape (N, C, H, W). - ratio_list: np.ndarray of shape(N,2) - polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. - ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not. - preds: a list of dict produced by post process - points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. - ''' - gt_polyons_batch = batch[2] temp_gt_strs_batch = batch[3] ignore_tags_batch = batch[4] @@ -72,13 +62,6 @@ class E2EMetric(object): self.results.append(result) def get_metric(self): - """ - return metrics { - 'precision': 0, - 'recall': 0, - 'hmean': 0 - } - """ metircs = combine_results(self.results) self.reset() return metircs diff --git a/ppocr/modeling/necks/pg_fpn.py b/ppocr/modeling/necks/pg_fpn.py index 9bd560c9..ba14c1b2 100644 --- a/ppocr/modeling/necks/pg_fpn.py +++ b/ppocr/modeling/necks/pg_fpn.py @@ -106,172 +106,212 @@ class DeConvBNLayer(nn.Layer): return x -class FPN_Up_Fusion(nn.Layer): - def __init__(self, in_channels): - super(FPN_Up_Fusion, self).__init__() - in_channels = in_channels[::-1] - out_channels = [256, 256, 192, 192, 128] +class PGFPN(nn.Layer): + def __init__(self, in_channels, **kwargs): + super(PGFPN, self).__init__() + num_inputs = [2048, 2048, 1024, 512, 256] + num_outputs = [256, 256, 192, 192, 128] + self.out_channels = 128 + # print(in_channels) + self.conv_bn_layer_1 = ConvBNLayer( + in_channels=3, + out_channels=32, + kernel_size=3, + stride=1, + act=None, + name='FPN_d1') + self.conv_bn_layer_2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + act=None, + name='FPN_d2') + self.conv_bn_layer_3 = ConvBNLayer( + in_channels=256, + out_channels=128, + kernel_size=3, + stride=1, + act=None, + name='FPN_d3') + self.conv_bn_layer_4 = ConvBNLayer( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=2, + act=None, + name='FPN_d4') + self.conv_bn_layer_5 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + act='relu', + name='FPN_d5') + self.conv_bn_layer_6 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=3, + stride=2, + act=None, + name='FPN_d6') + self.conv_bn_layer_7 = ConvBNLayer( + in_channels=128, + out_channels=128, + kernel_size=3, + stride=1, + act='relu', + name='FPN_d7') + self.conv_bn_layer_8 = ConvBNLayer( + in_channels=128, + out_channels=128, + kernel_size=1, + stride=1, + act=None, + name='FPN_d8') - self.h0_conv = ConvBNLayer( - in_channels[0], out_channels[0], 1, 1, act=None, name='conv_h0') - self.h1_conv = ConvBNLayer( - in_channels[1], out_channels[1], 1, 1, act=None, name='conv_h1') - self.h2_conv = ConvBNLayer( - in_channels[2], out_channels[2], 1, 1, act=None, name='conv_h2') - self.h3_conv = ConvBNLayer( - in_channels[3], out_channels[3], 1, 1, act=None, name='conv_h3') - self.h4_conv = ConvBNLayer( - in_channels[4], out_channels[4], 1, 1, act=None, name='conv_h4') + self.conv_h0 = ConvBNLayer( + in_channels=num_inputs[0], + out_channels=num_outputs[0], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(0)) + self.conv_h1 = ConvBNLayer( + in_channels=num_inputs[1], + out_channels=num_outputs[1], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(1)) + self.conv_h2 = ConvBNLayer( + in_channels=num_inputs[2], + out_channels=num_outputs[2], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(2)) + self.conv_h3 = ConvBNLayer( + in_channels=num_inputs[3], + out_channels=num_outputs[3], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(3)) + self.conv_h4 = ConvBNLayer( + in_channels=num_inputs[4], + out_channels=num_outputs[4], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(4)) self.dconv0 = DeConvBNLayer( - in_channels=out_channels[0], - out_channels=out_channels[1], + in_channels=num_outputs[0], + out_channels=num_outputs[0 + 1], name="dconv_{}".format(0)) self.dconv1 = DeConvBNLayer( - in_channels=out_channels[1], - out_channels=out_channels[2], + in_channels=num_outputs[1], + out_channels=num_outputs[1 + 1], act=None, name="dconv_{}".format(1)) self.dconv2 = DeConvBNLayer( - in_channels=out_channels[2], - out_channels=out_channels[3], + in_channels=num_outputs[2], + out_channels=num_outputs[2 + 1], act=None, name="dconv_{}".format(2)) self.dconv3 = DeConvBNLayer( - in_channels=out_channels[3], - out_channels=out_channels[4], + in_channels=num_outputs[3], + out_channels=num_outputs[3 + 1], act=None, name="dconv_{}".format(3)) self.conv_g1 = ConvBNLayer( - in_channels=out_channels[1], - out_channels=out_channels[1], + in_channels=num_outputs[1], + out_channels=num_outputs[1], kernel_size=3, stride=1, act='relu', name="conv_g{}".format(1)) self.conv_g2 = ConvBNLayer( - in_channels=out_channels[2], - out_channels=out_channels[2], + in_channels=num_outputs[2], + out_channels=num_outputs[2], kernel_size=3, stride=1, act='relu', name="conv_g{}".format(2)) self.conv_g3 = ConvBNLayer( - in_channels=out_channels[3], - out_channels=out_channels[3], + in_channels=num_outputs[3], + out_channels=num_outputs[3], kernel_size=3, stride=1, act='relu', name="conv_g{}".format(3)) self.conv_g4 = ConvBNLayer( - in_channels=out_channels[4], - out_channels=out_channels[4], + in_channels=num_outputs[4], + out_channels=num_outputs[4], kernel_size=3, stride=1, act='relu', name="conv_g{}".format(4)) self.convf = ConvBNLayer( - in_channels=out_channels[4], - out_channels=out_channels[4], + in_channels=num_outputs[4], + out_channels=num_outputs[4], kernel_size=1, stride=1, act=None, name="conv_f{}".format(4)) - def _add_relu(self, x1, x2): - x = paddle.add(x=x1, y=x2) - x = F.relu(x) - return x - def forward(self, x): - f = x[2:][::-1] - h0 = self.h0_conv(f[0]) - h1 = self.h1_conv(f[1]) - h2 = self.h2_conv(f[2]) - h3 = self.h3_conv(f[3]) - h4 = self.h4_conv(f[4]) + c0, c1, c2, c3, c4, c5, c6 = x + # FPN_Down_Fusion + f = [c0, c1, c2] + g = [None, None, None] + h = [None, None, None] + h[0] = self.conv_bn_layer_1(f[0]) + h[1] = self.conv_bn_layer_2(f[1]) + h[2] = self.conv_bn_layer_3(f[2]) - g0 = self.dconv0(h0) + g[0] = self.conv_bn_layer_4(h[0]) + g[1] = paddle.add(g[0], h[1]) + g[1] = F.relu(g[1]) + g[1] = self.conv_bn_layer_5(g[1]) + g[1] = self.conv_bn_layer_6(g[1]) - g1 = self.dconv2(self.conv_g2(self._add_relu(g0, h1))) - g2 = self.dconv2(self.conv_g2(self._add_relu(g1, h2))) - g3 = self.dconv3(self.conv_g2(self._add_relu(g2, h3))) - g4 = self.dconv4(self.conv_g2(self._add_relu(g3, h4))) - return g4 + g[2] = paddle.add(g[1], h[2]) + g[2] = F.relu(g[2]) + g[2] = self.conv_bn_layer_7(g[2]) + f_down = self.conv_bn_layer_8(g[2]) + # FPN UP Fusion + f1 = [c6, c5, c4, c3, c2] + g = [None, None, None, None, None] + h = [None, None, None, None, None] + h[0] = self.conv_h0(f1[0]) + h[1] = self.conv_h1(f1[1]) + h[2] = self.conv_h2(f1[2]) + h[3] = self.conv_h3(f1[3]) + h[4] = self.conv_h4(f1[4]) -class FPN_Down_Fusion(nn.Layer): - def __init__(self, in_channels): - super(FPN_Down_Fusion, self).__init__() - out_channels = [32, 64, 128] + g[0] = self.dconv0(h[0]) + g[1] = paddle.add(g[0], h[1]) + g[1] = F.relu(g[1]) + g[1] = self.conv_g1(g[1]) + g[1] = self.dconv1(g[1]) - self.h0_conv = ConvBNLayer( - in_channels[0], out_channels[0], 3, 1, act=None, name='FPN_d1') - self.h1_conv = ConvBNLayer( - in_channels[1], out_channels[1], 3, 1, act=None, name='FPN_d2') - self.h2_conv = ConvBNLayer( - in_channels[2], out_channels[2], 3, 1, act=None, name='FPN_d3') + g[2] = paddle.add(g[1], h[2]) + g[2] = F.relu(g[2]) + g[2] = self.conv_g2(g[2]) + g[2] = self.dconv2(g[2]) - self.g0_conv = ConvBNLayer( - out_channels[0], out_channels[1], 3, 2, act=None, name='FPN_d4') + g[3] = paddle.add(g[2], h[3]) + g[3] = F.relu(g[3]) + g[3] = self.conv_g3(g[3]) + g[3] = self.dconv3(g[3]) - self.g1_conv = nn.Sequential( - ConvBNLayer( - out_channels[1], - out_channels[1], - 3, - 1, - act='relu', - name='FPN_d5'), - ConvBNLayer( - out_channels[1], out_channels[2], 3, 2, act=None, - name='FPN_d6')) - - self.g2_conv = nn.Sequential( - ConvBNLayer( - out_channels[2], - out_channels[2], - 3, - 1, - act='relu', - name='FPN_d7'), - ConvBNLayer( - out_channels[2], out_channels[2], 1, 1, act=None, - name='FPN_d8')) - - def forward(self, x): - f = x[:3] - h0 = self.h0_conv(f[0]) - h1 = self.h1_conv(f[1]) - h2 = self.h2_conv(f[2]) - g0 = self.g0_conv(h0) - g1 = paddle.add(x=g0, y=h1) - g1 = F.relu(g1) - g1 = self.g1_conv(g1) - g2 = paddle.add(x=g1, y=h2) - g2 = F.relu(g2) - g2 = self.g2_conv(g2) - return g2 - - -class PGFPN(nn.Layer): - def __init__(self, in_channels, with_cab=False, **kwargs): - super(PGFPN, self).__init__() - self.in_channels = in_channels - self.with_cab = with_cab - self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels) - self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels) - self.out_channels = 128 - - def forward(self, x): - # down fpn - f_down = self.FPN_Down_Fusion(x) - - # up fpn - f_up = self.FPN_Up_Fusion(x) - - # fusion - f_common = paddle.add(x=f_down, y=f_up) + g[4] = paddle.add(x=g[3], y=h[4]) + g[4] = F.relu(g[4]) + g[4] = self.conv_g4(g[4]) + f_up = self.convf(g[4]) + f_common = paddle.add(f_down, f_up) f_common = F.relu(f_common) - return f_common diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index fd12ecab..8337e539 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -1,9 +1,18 @@ -from os import listdir -import os, sys -from scipy import io +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import numpy as np from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area -from tqdm import tqdm try: # python2 range = xrange @@ -862,16 +871,3 @@ def combine_results(all_data): 'f_score_e2e': f_score_e2e } return final - - -# a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, 1659, 620, 1654, 681, 1631, 680, 1618, -# 681, 1606, 681, 1594, 681, 1584, 682, 1573, 685, 1542, 694] -# gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}] -# pred_dict = [{'points': np.array(a), 'text': 'ccc'}, -# {'points': np.array(a), 'text': 'ccf'}] -# result = [] -# for i in range(2): -# result.append(get_socre(gt_dict, pred_dict)) -# print(111) -# a = combine_results(result) -# print(a) diff --git a/ppocr/utils/e2e_metric/polygon_fast.py b/ppocr/utils/e2e_metric/polygon_fast.py index 0173212e..c78e2a1e 100755 --- a/ppocr/utils/e2e_metric/polygon_fast.py +++ b/ppocr/utils/e2e_metric/polygon_fast.py @@ -1,6 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import numpy as np from shapely.geometry import Polygon -#import Polygon """ :param det_x: [1, N] Xs of detection's vertices :param det_y: [1, N] Ys of detection's vertices diff --git a/ppocr/utils/e2e_metric/tttt.py b/ppocr/utils/e2e_metric/tttt.py deleted file mode 100644 index 91d893fd..00000000 --- a/ppocr/utils/e2e_metric/tttt.py +++ /dev/null @@ -1,881 +0,0 @@ -from os import listdir -import os, sys -from scipy import io -import numpy as np -from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area -from tqdm import tqdm - -try: # python2 - range = xrange -except Exception: - # python3 - range = range -""" -Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')' -""" - -# if len(sys.argv) != 4: -# print('\n usage: test.py pred_dir gt_dir savefile') -# sys.exit() -global_tp = 0 -global_fp = 0 -global_fn = 0 - -tr = 0.7 -tp = 0.6 -fsc_k = 0.8 -k = 2 - - -def get_socre(gt_dict, pred_dict): - # allInputs = listdir(input_dir) - allInputs = 1 - global_pred_str = [] - global_gt_str = [] - global_sigma = [] - global_tau = [] - - def input_reading_mod(pred_dict, input): - """This helper reads input from txt files""" - det = [] - n = len(pred_dict) - for i in range(n): - points = pred_dict[i]['points'] - text = pred_dict[i]['text'] - # for i in range(len(points)): - point = ",".join(map(str, points.reshape(-1, ))) - det.append([point, text]) - return det - - def gt_reading_mod(gt_dict, gt_id): - """This helper reads groundtruths from mat files""" - # gt_id = gt_id.split('.')[0] - gt = [] - n = len(gt_dict) - for i in range(n): - points = gt_dict[i]['points'].tolist() - h = len(points) - text = gt_dict[i]['text'] - xx = [ - np.array( - ['x:'], dtype=' 1): - gt_x = list(map(int, np.squeeze(gt[1]))) - gt_y = list(map(int, np.squeeze(gt[3]))) - for det_id, detection in enumerate(detections): - detection_orig = detection - detection = [float(x) for x in detection[0].split(',')] - # detection = detection.split(',') - detection = list(map(int, detection)) - det_x = detection[0::2] - det_y = detection[1::2] - det_gt_iou = iod(det_x, det_y, gt_x, gt_y) - if det_gt_iou > threshold: - detections[det_id] = [] - - detections[:] = [item for item in detections if item != []] - return detections - - def sigma_calculation(det_x, det_y, gt_x, gt_y): - """ - sigma = inter_area / gt_area - """ - # print(area_of_intersection(det_x, det_y, gt_x, gt_y)) - return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / - area(gt_x, gt_y)), 2) - - def tau_calculation(det_x, det_y, gt_x, gt_y): - """ - tau = inter_area / det_area - """ - # print "liushanshan det_x {}".format(det_x) - # print "liushanshan det_y {}".format(det_y) - # print "liushanshan area {}".format(area(det_x, det_y)) - # print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2)) - if area(det_x, det_y) == 0.0: - return 0 - return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / - area(det_x, det_y)), 2) - - ##############################Initialization################################### - - ############################################################################### - single_data = {} - for input_id in range(allInputs): - - if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and ( - input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and ( - input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ - and (input_id != 'Deteval_result_non_curved.txt'): - print(input_id) - detections = input_reading_mod(pred_dict, input_id) - # print "liushanshan detections = {}".format(detections) - groundtruths = gt_reading_mod(gt_dict, input_id) - detections = detection_filtering( - detections, - groundtruths) # filters detections overlapping with DC area - dc_id = [] - for i in range(len(groundtruths)): - if groundtruths[i][5] == '#': - dc_id.append(i) - cnt = 0 - for a in dc_id: - num = a - cnt - del groundtruths[num] - cnt += 1 - - local_sigma_table = np.zeros((len(groundtruths), len(detections))) - local_tau_table = np.zeros((len(groundtruths), len(detections))) - local_pred_str = {} - local_gt_str = {} - - for gt_id, gt in enumerate(groundtruths): - if len(detections) > 0: - for det_id, detection in enumerate(detections): - detection_orig = detection - detection = [float(x) for x in detection[0].split(',')] - detection = list(map(int, detection)) - pred_seq_str = detection_orig[1].strip() - det_x = detection[0::2] - det_y = detection[1::2] - gt_x = list(map(int, np.squeeze(gt[1]))) - gt_y = list(map(int, np.squeeze(gt[3]))) - gt_seq_str = str(gt[4].tolist()[0]) - - local_sigma_table[gt_id, det_id] = sigma_calculation( - det_x, det_y, gt_x, gt_y) - local_tau_table[gt_id, det_id] = tau_calculation( - det_x, det_y, gt_x, gt_y) - local_pred_str[det_id] = pred_seq_str - local_gt_str[gt_id] = gt_seq_str - - global_sigma.append(local_sigma_table) - global_tau.append(local_tau_table) - global_pred_str.append(local_pred_str) - global_gt_str.append(local_gt_str) - print - "liushanshan global_pred_str = {}".format(global_pred_str) - print - "liushanshan global_gt_str = {}".format(global_gt_str) - single_data['sigma'] = global_sigma - single_data['global_tau'] = global_tau - single_data['global_pred_str'] = global_pred_str - single_data['global_gt_str'] = global_gt_str - return single_data - - -def combine_results(all_data): - global_sigma, global_tau, global_pred_str, global_gt_str = [], [], [], [] - for data in all_data: - global_sigma.append(data['sigma']) - global_tau.append(data['global_tau']) - global_pred_str.append(data['global_pred_str']) - global_gt_str.append(data['global_gt_str']) - global_accumulative_recall = 0 - global_accumulative_precision = 0 - total_num_gt = 0 - total_num_det = 0 - hit_str_count = 0 - hit_count = 0 - - def one_to_one(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): - hit_str_num = 0 - for gt_id in range(num_gt): - gt_matching_qualified_sigma_candidates = np.where( - local_sigma_table[gt_id, :] > tr) - gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[ - 0].shape[0] - gt_matching_qualified_tau_candidates = np.where( - local_tau_table[gt_id, :] > tp) - gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[ - 0].shape[0] - - det_matching_qualified_sigma_candidates = np.where( - local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] - > tr) - det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[ - 0].shape[0] - det_matching_qualified_tau_candidates = np.where( - local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > - tp) - det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[ - 0].shape[0] - - if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \ - (det_matching_num_qualified_sigma_candidates == 1) and ( - det_matching_num_qualified_tau_candidates == 1): - global_accumulative_recall = global_accumulative_recall + 1.0 - global_accumulative_precision = global_accumulative_precision + 1.0 - local_accumulative_recall = local_accumulative_recall + 1.0 - local_accumulative_precision = local_accumulative_precision + 1.0 - - gt_flag[0, gt_id] = 1 - matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) - # recg start - print - "liushanshan one to one det_id = {}".format(matched_det_id) - print - "liushanshan one to one gt_id = {}".format(gt_id) - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ - 0]] - print - "liushanshan one to one gt_str_cur = {}".format(gt_str_cur) - print - "liushanshan one to one pred_str_cur = {}".format(pred_str_cur) - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - # recg end - det_flag[0, matched_det_id] = 1 - return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num - - def one_to_many(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): - hit_str_num = 0 - for gt_id in range(num_gt): - # skip the following if the groundtruth was matched - if gt_flag[0, gt_id] > 0: - continue - - non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0) - num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0] - - if num_non_zero_in_sigma >= k: - ####search for all detections that overlaps with this groundtruth - qualified_tau_candidates = np.where((local_tau_table[ - gt_id, :] >= tp) & (det_flag[0, :] == 0)) - num_qualified_tau_candidates = qualified_tau_candidates[ - 0].shape[0] - - if num_qualified_tau_candidates == 1: - if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) - and - (local_sigma_table[gt_id, qualified_tau_candidates] >= - tr)): - # became an one-to-one case - global_accumulative_recall = global_accumulative_recall + 1.0 - global_accumulative_precision = global_accumulative_precision + 1.0 - local_accumulative_recall = local_accumulative_recall + 1.0 - local_accumulative_precision = local_accumulative_precision + 1.0 - - gt_flag[0, gt_id] = 1 - det_flag[0, qualified_tau_candidates] = 1 - # recg start - print - "liushanshan one to many det_id = {}".format( - qualified_tau_candidates) - print - "liushanshan one to many gt_id = {}".format(gt_id) - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - print - "liushanshan one to many gt_str_cur = {}".format( - gt_str_cur) - print - "liushanshan one to many pred_str_cur = {}".format( - pred_str_cur) - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - # recg end - elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) - >= tr): - gt_flag[0, gt_id] = 1 - det_flag[0, qualified_tau_candidates] = 1 - # recg start - print - "liushanshan one to many det_id = {}".format( - qualified_tau_candidates) - print - "liushanshan one to many gt_id = {}".format(gt_id) - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - print - "liushanshan one to many gt_str_cur = {}".format(gt_str_cur) - print - "liushanshan one to many pred_str_cur = {}".format( - pred_str_cur) - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - # recg end - - global_accumulative_recall = global_accumulative_recall + fsc_k - global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k - - local_accumulative_recall = local_accumulative_recall + fsc_k - local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k - - return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num - - def many_to_one(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): - hit_str_num = 0 - for det_id in range(num_det): - # skip the following if the detection was matched - if det_flag[0, det_id] > 0: - continue - - non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0) - num_non_zero_in_tau = non_zero_in_tau[0].shape[0] - - if num_non_zero_in_tau >= k: - ####search for all detections that overlaps with this groundtruth - qualified_sigma_candidates = np.where(( - local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0)) - num_qualified_sigma_candidates = qualified_sigma_candidates[ - 0].shape[0] - - if num_qualified_sigma_candidates == 1: - if ((local_tau_table[qualified_sigma_candidates, det_id] >= - tp) and - (local_sigma_table[qualified_sigma_candidates, det_id] - >= tr)): - # became an one-to-one case - global_accumulative_recall = global_accumulative_recall + 1.0 - global_accumulative_precision = global_accumulative_precision + 1.0 - local_accumulative_recall = local_accumulative_recall + 1.0 - local_accumulative_precision = local_accumulative_precision + 1.0 - - gt_flag[0, qualified_sigma_candidates] = 1 - det_flag[0, det_id] = 1 - # recg start - print - "liushanshan many to one det_id = {}".format(det_id) - print - "liushanshan many to one gt_id = {}".format( - qualified_sigma_candidates) - pred_str_cur = global_pred_str[idy][det_id] - gt_len = len(qualified_sigma_candidates[0]) - for idx in range(gt_len): - ele_gt_id = qualified_sigma_candidates[0].tolist()[ - idx] - if not global_gt_str[idy].has_key(ele_gt_id): - continue - gt_str_cur = global_gt_str[idy][ele_gt_id] - print - "liushanshan many to one gt_str_cur = {}".format( - gt_str_cur) - print - "liushanshan many to one pred_str_cur = {}".format( - pred_str_cur) - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - break - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - break - # recg end - elif (np.sum(local_tau_table[qualified_sigma_candidates, - det_id]) >= tp): - det_flag[0, det_id] = 1 - gt_flag[0, qualified_sigma_candidates] = 1 - # recg start - print - "liushanshan many to one det_id = {}".format(det_id) - print - "liushanshan many to one gt_id = {}".format( - qualified_sigma_candidates) - pred_str_cur = global_pred_str[idy][det_id] - gt_len = len(qualified_sigma_candidates[0]) - for idx in range(gt_len): - ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] - if not global_gt_str[idy].has_key(ele_gt_id): - continue - gt_str_cur = global_gt_str[idy][ele_gt_id] - print - "liushanshan many to one gt_str_cur = {}".format( - gt_str_cur) - print - "liushanshan many to one pred_str_cur = {}".format( - pred_str_cur) - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - break - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - break - else: - print - 'no match' - # recg end - - global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k - global_accumulative_precision = global_accumulative_precision + fsc_k - - local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k - local_accumulative_precision = local_accumulative_precision + fsc_k - return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num - - for idx in range(len(global_sigma)): - # print(allInputs[idx]) - local_sigma_table = np.array(global_sigma[idx]) - local_tau_table = global_tau[idx] - - num_gt = local_sigma_table.shape[0] - num_det = local_sigma_table.shape[1] - - total_num_gt = total_num_gt + num_gt - total_num_det = total_num_det + num_det - - local_accumulative_recall = 0 - local_accumulative_precision = 0 - gt_flag = np.zeros((1, num_gt)) - det_flag = np.zeros((1, num_det)) - - #######first check for one-to-one case########## - local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ - gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) - - hit_str_count += hit_str_num - #######then check for one-to-many case########## - local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ - gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) - hit_str_count += hit_str_num - #######then check for many-to-one case########## - local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ - gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) - - try: - recall = global_accumulative_recall / total_num_gt - except ZeroDivisionError: - recall = 0 - - try: - precision = global_accumulative_precision / total_num_det - except ZeroDivisionError: - precision = 0 - - try: - f_score = 2 * precision * recall / (precision + recall) - except ZeroDivisionError: - f_score = 0 - - try: - seqerr = 1 - float(hit_str_count) / global_accumulative_recall - except ZeroDivisionError: - seqerr = 1 - - try: - recall_e2e = float(hit_str_count) / total_num_gt - except ZeroDivisionError: - recall_e2e = 0 - - try: - precision_e2e = float(hit_str_count) / total_num_det - except ZeroDivisionError: - precision_e2e = 0 - - try: - f_score_e2e = 2 * precision_e2e * recall_e2e / ( - precision_e2e + recall_e2e) - except ZeroDivisionError: - f_score_e2e = 0 - - final = { - 'total_num_gt': total_num_gt, - 'total_num_det': total_num_det, - 'global_accumulative_recall': global_accumulative_recall, - 'hit_str_count': hit_str_count, - 'recall': recall, - 'precision': precision, - 'f_score': f_score, - 'seqerr': seqerr, - 'recall_e2e': recall_e2e, - 'precision_e2e': precision_e2e, - 'f_score_e2e': f_score_e2e - } - return final - - -# def combine_results(all_data): -# tr = 0.7 -# tp = 0.6 -# fsc_k = 0.8 -# k = 2 -# global_sigma = [] -# global_tau = [] -# global_pred_str = [] -# global_gt_str = [] -# for data in all_data: -# global_sigma.append(data['sigma']) -# global_tau.append(data['global_tau']) -# global_pred_str.append(data['global_pred_str']) -# global_gt_str.append(data['global_gt_str']) -# -# global_accumulative_recall = 0 -# global_accumulative_precision = 0 -# total_num_gt = 0 -# total_num_det = 0 -# hit_str_count = 0 -# hit_count = 0 -# -# def one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, -# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, -# gt_flag, det_flag, idy): -# hit_str_num = 0 -# for gt_id in range(num_gt): -# gt_matching_qualified_sigma_candidates = np.where(local_sigma_table[gt_id, :] > tr) -# gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[0].shape[0] -# gt_matching_qualified_tau_candidates = np.where(local_tau_table[gt_id, :] > tp) -# gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[0].shape[0] -# -# det_matching_qualified_sigma_candidates = np.where( -# local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] > tr) -# det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[0].shape[0] -# det_matching_qualified_tau_candidates = np.where( -# local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > tp) -# det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[0].shape[0] -# -# if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \ -# (det_matching_num_qualified_sigma_candidates == 1) and ( -# det_matching_num_qualified_tau_candidates == 1): -# global_accumulative_recall = global_accumulative_recall + 1.0 -# global_accumulative_precision = global_accumulative_precision + 1.0 -# local_accumulative_recall = local_accumulative_recall + 1.0 -# local_accumulative_precision = local_accumulative_precision + 1.0 -# -# gt_flag[0, gt_id] = 1 -# matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) -# # recg start -# print -# "liushanshan one to one det_id = {}".format(matched_det_id) -# print -# "liushanshan one to one gt_id = {}".format(gt_id) -# gt_str_cur = global_gt_str[idy][gt_id] -# pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[0]] -# print -# "liushanshan one to one gt_str_cur = {}".format(gt_str_cur) -# print -# "liushanshan one to one pred_str_cur = {}".format(pred_str_cur) -# if pred_str_cur == gt_str_cur: -# hit_str_num += 1 -# else: -# if pred_str_cur.lower() == gt_str_cur.lower(): -# hit_str_num += 1 -# # recg end -# det_flag[0, matched_det_id] = 1 -# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num -# -# def one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall, -# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, -# gt_flag, det_flag, idy): -# hit_str_num = 0 -# for gt_id in range(num_gt): -# # skip the following if the groundtruth was matched -# if gt_flag[0, gt_id] > 0: -# continue -# -# non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0) -# num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0] -# -# if num_non_zero_in_sigma >= k: -# ####search for all detections that overlaps with this groundtruth -# qualified_tau_candidates = np.where((local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0)) -# num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0] -# -# if num_qualified_tau_candidates == 1: -# if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) and ( -# local_sigma_table[gt_id, qualified_tau_candidates] >= tr)): -# # became an one-to-one case -# global_accumulative_recall = global_accumulative_recall + 1.0 -# global_accumulative_precision = global_accumulative_precision + 1.0 -# local_accumulative_recall = local_accumulative_recall + 1.0 -# local_accumulative_precision = local_accumulative_precision + 1.0 -# -# gt_flag[0, gt_id] = 1 -# det_flag[0, qualified_tau_candidates] = 1 -# # recg start -# print -# "liushanshan one to many det_id = {}".format(qualified_tau_candidates) -# print -# "liushanshan one to many gt_id = {}".format(gt_id) -# gt_str_cur = global_gt_str[idy][gt_id] -# pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]] -# print -# "liushanshan one to many gt_str_cur = {}".format(gt_str_cur) -# print -# "liushanshan one to many pred_str_cur = {}".format(pred_str_cur) -# if pred_str_cur == gt_str_cur: -# hit_str_num += 1 -# else: -# if pred_str_cur.lower() == gt_str_cur.lower(): -# hit_str_num += 1 -# # recg end -# elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr): -# gt_flag[0, gt_id] = 1 -# det_flag[0, qualified_tau_candidates] = 1 -# # recg start -# print -# "liushanshan one to many det_id = {}".format(qualified_tau_candidates) -# print -# "liushanshan one to many gt_id = {}".format(gt_id) -# gt_str_cur = global_gt_str[idy][gt_id] -# pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]] -# print -# "liushanshan one to many gt_str_cur = {}".format(gt_str_cur) -# print -# "liushanshan one to many pred_str_cur = {}".format(pred_str_cur) -# if pred_str_cur == gt_str_cur: -# hit_str_num += 1 -# else: -# if pred_str_cur.lower() == gt_str_cur.lower(): -# hit_str_num += 1 -# # recg end -# -# global_accumulative_recall = global_accumulative_recall + fsc_k -# global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k -# -# local_accumulative_recall = local_accumulative_recall + fsc_k -# local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k -# -# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num -# -# def many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, -# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, -# gt_flag, det_flag, idy): -# hit_str_num = 0 -# for det_id in range(num_det): -# # skip the following if the detection was matched -# if det_flag[0, det_id] > 0: -# continue -# -# non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0) -# num_non_zero_in_tau = non_zero_in_tau[0].shape[0] -# -# if num_non_zero_in_tau >= k: -# ####search for all detections that overlaps with this groundtruth -# qualified_sigma_candidates = np.where((local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0)) -# num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0] -# -# if num_qualified_sigma_candidates == 1: -# if ((local_tau_table[qualified_sigma_candidates, det_id] >= tp) and ( -# local_sigma_table[qualified_sigma_candidates, det_id] >= tr)): -# # became an one-to-one case -# global_accumulative_recall = global_accumulative_recall + 1.0 -# global_accumulative_precision = global_accumulative_precision + 1.0 -# local_accumulative_recall = local_accumulative_recall + 1.0 -# local_accumulative_precision = local_accumulative_precision + 1.0 -# -# gt_flag[0, qualified_sigma_candidates] = 1 -# det_flag[0, det_id] = 1 -# # recg start -# print -# "liushanshan many to one det_id = {}".format(det_id) -# print -# "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates) -# pred_str_cur = global_pred_str[idy][det_id] -# gt_len = len(qualified_sigma_candidates[0]) -# for idx in range(gt_len): -# ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] -# if ele_gt_id not in global_gt_str[idy]: -# continue -# gt_str_cur = global_gt_str[idy][ele_gt_id] -# print -# "liushanshan many to one gt_str_cur = {}".format(gt_str_cur) -# print -# "liushanshan many to one pred_str_cur = {}".format(pred_str_cur) -# if pred_str_cur == gt_str_cur: -# hit_str_num += 1 -# break -# else: -# if pred_str_cur.lower() == gt_str_cur.lower(): -# hit_str_num += 1 -# break -# # recg end -# elif (np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp): -# det_flag[0, det_id] = 1 -# gt_flag[0, qualified_sigma_candidates] = 1 -# # recg start -# print -# "liushanshan many to one det_id = {}".format(det_id) -# print -# "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates) -# pred_str_cur = global_pred_str[idy][det_id] -# gt_len = len(qualified_sigma_candidates[0]) -# for idx in range(gt_len): -# ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] -# if not global_gt_str[idy].has_key(ele_gt_id): -# continue -# gt_str_cur = global_gt_str[idy][ele_gt_id] -# print -# "liushanshan many to one gt_str_cur = {}".format(gt_str_cur) -# print -# "liushanshan many to one pred_str_cur = {}".format(pred_str_cur) -# if pred_str_cur == gt_str_cur: -# hit_str_num += 1 -# break -# else: -# if pred_str_cur.lower() == gt_str_cur.lower(): -# hit_str_num += 1 -# break -# else: -# print -# 'no match' -# # recg end -# -# global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k -# global_accumulative_precision = global_accumulative_precision + fsc_k -# -# local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k -# local_accumulative_precision = local_accumulative_precision + fsc_k -# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num -# -# for idx in range(len(global_sigma)): -# local_sigma_table = np.array(global_sigma[idx]) -# local_tau_table = np.array(global_tau[idx]) -# -# num_gt = local_sigma_table.shape[0] -# num_det = local_sigma_table.shape[1] -# -# total_num_gt = total_num_gt + num_gt -# total_num_det = total_num_det + num_det -# -# local_accumulative_recall = 0 -# local_accumulative_precision = 0 -# gt_flag = np.zeros((1, num_gt)) -# det_flag = np.zeros((1, num_det)) -# -# #######first check for one-to-one case########## -# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ -# gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, -# local_accumulative_recall, local_accumulative_precision, -# global_accumulative_recall, global_accumulative_precision, -# gt_flag, det_flag, idx) -# -# hit_str_count += hit_str_num -# #######then check for one-to-many case########## -# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ -# gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, -# local_accumulative_recall, local_accumulative_precision, -# global_accumulative_recall, global_accumulative_precision, -# gt_flag, det_flag, idx) -# hit_str_count += hit_str_num -# #######then check for many-to-one case########## -# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ -# gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table, -# local_accumulative_recall, local_accumulative_precision, -# global_accumulative_recall, global_accumulative_precision, -# gt_flag, det_flag, idx) -# try: -# recall = global_accumulative_recall / total_num_gt -# except ZeroDivisionError: -# recall = 0 -# -# try: -# precision = global_accumulative_precision / total_num_det -# except ZeroDivisionError: -# precision = 0 -# -# try: -# f_score = 2 * precision * recall / (precision + recall) -# except ZeroDivisionError: -# f_score = 0 -# -# try: -# seqerr = 1 - float(hit_str_count) / global_accumulative_recall -# except ZeroDivisionError: -# seqerr = 1 -# -# try: -# recall_e2e = float(hit_str_count) / total_num_gt -# except ZeroDivisionError: -# recall_e2e = 0 -# -# try: -# precision_e2e = float(hit_str_count) / total_num_det -# except ZeroDivisionError: -# precision_e2e = 0 -# -# try: -# f_score_e2e = 2 * precision_e2e * recall_e2e / (precision_e2e + recall_e2e) -# except ZeroDivisionError: -# f_score_e2e = 0 -# -# final = { -# 'total_num_gt': total_num_gt, -# 'total_num_det': total_num_det, -# 'global_accumulative_recall': global_accumulative_recall, -# 'hit_str_count': hit_str_count, -# 'recall': recall, -# 'precision': precision, -# 'f_score': f_score, -# 'seqerr': seqerr, -# 'recall_e2e': recall_e2e, -# 'precision_e2e': precision_e2e, -# 'f_score_e2e': f_score_e2e -# } -# return final - -a = [ - 1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, - 1659, 620, 1654, 681, 1631, 680, 1618, 681, 1606, 681, 1594, 681, 1584, 682, - 1573, 685, 1542, 694 -] -gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}] -pred_dict = [{ - 'points': np.array(a), - 'text': 'ccc' -}, { - 'points': np.array(a), - 'text': 'ccf' -}] -result = [] -result.append(get_socre(gt_dict, gt_dict)) -a = combine_results(result) -print(a) diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint.py index 96ebf02e..1665c7ef 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint.py +++ b/ppocr/utils/e2e_utils/extract_textpoint.py @@ -1,3 +1,16 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Contains various CTC decoders.""" from __future__ import absolute_import from __future__ import division diff --git a/ppocr/utils/e2e_utils/ski_thin.py b/ppocr/utils/e2e_utils/ski_thin.py index dba2afdd..6b1e5c78 100644 --- a/ppocr/utils/e2e_utils/ski_thin.py +++ b/ppocr/utils/e2e_utils/ski_thin.py @@ -1,6 +1,16 @@ -""" -Algorithms for computing the skeleton of a binary image -""" +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import numpy as np from scipy import ndimage as ndi diff --git a/ppocr/utils/e2e_utils/visual.py b/ppocr/utils/e2e_utils/visual.py index 4c96e5c7..6be2107f 100644 --- a/ppocr/utils/e2e_utils/visual.py +++ b/ppocr/utils/e2e_utils/visual.py @@ -1,147 +1,21 @@ -import os +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import numpy as np import cv2 import time -def visualize_e2e_result(im_fn, poly_list, seq_strs, src_im): - """ - """ - result_path = './out' - im_basename = os.path.basename(im_fn) - im_prefix = im_basename[:im_basename.rfind('.')] - vis_det_img = src_im.copy() - valid_set = 'partvgg' - gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train" - text_path = os.path.join(gt_dir, im_prefix + '.txt') - fid = open(text_path, 'r') - lines = [line.strip() for line in fid.readlines()] - for line in lines: - if valid_set == 'partvgg': - tokens = line.strip().split('\t')[0].split(',') - # tokens = line.strip().split(',') - coords = tokens[:] - coords = list(map(float, coords)) - gt_poly = np.array(coords).reshape(1, 4, 2) - elif valid_set == 'totaltext': - tokens = line.strip().split('\t')[0].split(',') - coords = tokens[:] - coords_len = len(coords) / 2 - coords = list(map(float, coords)) - gt_poly = np.array(coords).reshape(1, coords_len, 2) - cv2.polylines( - vis_det_img, - np.array(gt_poly).astype(np.int32), - isClosed=True, - color=(255, 0, 0), - thickness=2) - - for detected_poly, recognized_str in zip(poly_list, seq_strs): - cv2.polylines( - vis_det_img, - np.array(detected_poly[np.newaxis, ...]).astype(np.int32), - isClosed=True, - color=(0, 0, 255), - thickness=2) - cv2.putText( - vis_det_img, - recognized_str, - org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])), - fontFace=cv2.FONT_HERSHEY_COMPLEX, - fontScale=0.7, - color=(0, 255, 0), - thickness=1) - - if not os.path.exists(result_path): - os.makedirs(result_path) - cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix), - vis_det_img) - - -def visualization_output(src_image, - f_tcl, - f_chars, - output_dir, - image_prefix=None): - """ - """ - # restore BGR image, CHW -> HWC - im_mean = [0.485, 0.456, 0.406] - im_std = [0.229, 0.224, 0.225] - - im_mean = np.array(im_mean).reshape((3, 1, 1)) - im_std = np.array(im_std).reshape((3, 1, 1)) - src_image *= im_std - src_image += im_mean - src_image = src_image.transpose([1, 2, 0]) - src_image = src_image[:, :, ::-1] * 255 # BGR -> RGB - H, W, _ = src_image.shape - - file_prefix = image_prefix if image_prefix is not None else str( - int(time.time() * 1000)) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - # visualization f_tcl - tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg') - vis_tcl_img = src_image.copy() - f_tcl_resized = cv2.resize(f_tcl, dsize=(W, H)) - vis_tcl_img[:, :, 1] = f_tcl_resized * 255 - cv2.imwrite(tcl_file_name, vis_tcl_img) - - # visualization char maps - vis_char_img = src_image.copy() - # CHW -> HWC - char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg') - f_chars = np.argmax(f_chars, axis=2)[:, :, np.newaxis].astype('float32') - f_chars[f_chars < 95] = 1.0 - f_chars[f_chars == 95] = 0.0 - f_chars_resized = cv2.resize(f_chars, dsize=(W, H)) - vis_char_img[:, :, 1] = f_chars_resized * 255 - cv2.imwrite(char_file_name, vis_char_img) - - -def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir, - result_path): - """ - """ - im_basename = os.path.basename(im_fn) - im_prefix = im_basename[:im_basename.rfind('.')] - vis_det_img = src_im.copy() - - # draw gt bbox on the image. - text_path = os.path.join(gt_dir, im_prefix + '.txt') - fid = open(text_path, 'r') - lines = [line.strip() for line in fid.readlines()] - for line in lines: - tokens = line.strip().split('\t') - coords = tokens[0].split(',') - coords_len = len(coords) - coords = list(map(float, coords)) - gt_poly = np.array(coords).reshape(1, coords_len / 2, 2) - cv2.polylines( - vis_det_img, - np.array(gt_poly).astype(np.int32), - isClosed=True, - color=(255, 255, 255), - thickness=1) - - for point, point_pair in zip(point_list, point_pair_list): - cv2.line( - vis_det_img, - tuple(point_pair[0]), - tuple(point_pair[1]), (0, 255, 255), - thickness=1) - cv2.circle(vis_det_img, tuple(point), 2, (0, 0, 255)) - cv2.circle(vis_det_img, tuple(point_pair[0]), 2, (255, 0, 0)) - cv2.circle(vis_det_img, tuple(point_pair[1]), 2, (0, 255, 0)) - - if not os.path.exists(result_path): - os.makedirs(result_path) - cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix), - vis_det_img) - - def resize_image(im, max_side_len=512): """ resize image to a size multiple of max_stride which is required by the network @@ -295,49 +169,3 @@ def norm2(x, axis=None): def cos(p1, p2): return (p1 * p2).sum() / (norm2(p1) * norm2(p2)) - - -def generate_direction_info(image_fn, - H, - W, - ratio_h, - ratio_w, - max_length=640, - out_scale=4, - gt_dir=None): - """ - """ - im_basename = os.path.basename(image_fn) - im_prefix = im_basename[:im_basename.rfind('.')] - instance_direction_map = np.zeros(shape=[H // out_scale, W // out_scale, 3]) - - if gt_dir is None: - gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly' - - # get gt label map - text_path = os.path.join(gt_dir, im_prefix + '.txt') - fid = open(text_path, 'r') - lines = [line.strip() for line in fid.readlines()] - for label_idx, line in enumerate(lines, start=1): - coords, txt = line.strip().split('\t') - if txt == '###': - continue - tokens = coords.strip().split(',') - coords = list(map(float, tokens)) - poly = np.array(coords).reshape(4, 2) * np.array( - [ratio_w, ratio_h]).reshape(1, 2) / out_scale - mid_idx = poly.shape[0] // 2 - direct_vector = ( - (poly[mid_idx] + poly[mid_idx - 1]) - (poly[0] + poly[-1])) / 2.0 - - direct_vector /= len(txt) - # l2_distance = norm2(direct_vector) - # avg_char_distance = l2_distance / len(txt) - avg_char_distance = 1.0 - - direct_label = (direct_vector[0], direct_vector[1], avg_char_distance) - cv2.fillPoly(instance_direction_map, - poly.round().astype(np.int32)[np.newaxis, :, :], - direct_label) - instance_direction_map = instance_direction_map.transpose([2, 0, 1]) - return instance_direction_map[:2, ...] diff --git a/tools/program.py b/tools/program.py index 778af8ec..f0e2aa08 100755 --- a/tools/program.py +++ b/tools/program.py @@ -44,7 +44,6 @@ class ArgsParser(ArgumentParser): def parse_args(self, argv=None): args = super(ArgsParser, self).parse_args(argv) - args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml' assert args.config is not None, \ "Please specify --config=configure_file_path." args.opt = self._parse_opt(args.opt) -- GitLab