提交 1f76f449 编写于 作者: J Jethong

Add PGNet

上级 1a087990
......@@ -14,12 +14,13 @@ Global:
load_static_weights: True
cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
checkpoints:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img:
infer_img:
save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt
Architecture:
model_type: det
algorithm: SAST
......
Global:
use_gpu: False
epoch_num: 600
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/pg_r50_vd_tt/
save_epoch_step: 1
# evaluation is run every 5000 iterationss after the 4000th iteration
eval_batch_step: [ 0, 1000 ]
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: False
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img:
save_res_path: ./output/pg_r50_vd_tt/predicts_pg.txt
Architecture:
model_type: e2e
algorithm: PG
Transform:
Backbone:
name: ResNet
layers: 50
Neck:
name: PGFPN
model_name: large
Head:
name: PGHead
model_name: large
Loss:
name: PGLoss
#Optimizer:
# name: Adam
# beta1: 0.9
# beta2: 0.999
# lr:
# name: Cosine
# learning_rate: 0.001
# warmup_epoch: 1
# regularizer:
# name: 'L2'
# factor: 0
Optimizer:
name: RMSProp
lr:
name: Piecewise
learning_rate: 0.001
decay_epochs: [ 40, 80, 120, 160, 200 ]
values: [ 0.001, 0.00033, 0.0001, 0.000033, 0.00001 ]
regularizer:
name: 'L2'
factor: 0.00005
PostProcess:
name: PGPostProcess
score_thresh: 0.8
cover_thresh: 0.1
nms_thresh: 0.2
Metric:
name: E2EMetric
main_indicator: hmean
Train:
dataset:
name: PGDateSet
label_file_list:
ratio_list:
data_format: textnet # textnet/partvgg
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- PGProcessTrain:
batch_size: 14
data_format: icdar
tcl_len: 64
min_crop_size: 24
min_text_size: 4
max_text_size: 512
- KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader:
shuffle: True
drop_last: True
batch_size_per_card: 1
num_workers: 8
Eval:
dataset:
name: PGDateSet
data_dir: ./train_data/
label_file_list:
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- E2ELabelEncode:
label_list: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
- E2EResizeForTest:
valid_set: totaltext
max_side_len: 768
- NormalizeImage:
scale: 1./255.
mean: [ 0.485, 0.456, 0.406 ]
std: [ 0.229, 0.224, 0.225 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
\ No newline at end of file
......@@ -34,6 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDateSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
......@@ -54,7 +55,8 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet']
support_dict = ['SimpleDataSet', 'LMDBDateSet', 'PGDateSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
......
......@@ -28,6 +28,7 @@ from .label_ops import *
from .east_process import *
from .sast_process import *
from .pg_process import *
def transform(data, ops=None):
......
......@@ -34,6 +34,25 @@ class ClsLabelEncode(object):
return data
class E2ELabelEncode(object):
def __init__(self, label_list, **kwargs):
self.label_list = label_list
def __call__(self, data):
text_label_index_list, temp_text = [], []
texts = data['strs']
for text in texts:
text = text.upper()
temp_text = []
for c_ in text:
if c_ in self.label_list:
temp_text.append(self.label_list.index(c_))
temp_text = temp_text + [36] * (50 - len(temp_text))
text_label_index_list.append(temp_text)
data['strs'] = np.array(text_label_index_list)
return data
class DetLabelEncode(object):
def __init__(self, **kwargs):
pass
......
......@@ -223,3 +223,74 @@ class DetResizeForTest(object):
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class E2EResizeForTest(object):
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
"""
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
import os
__all__ = ['PGProcessTrain']
class PGProcessTrain(object):
def __init__(self,
batch_size=14,
data_format='icdar',
tcl_len=64,
min_crop_size=24,
min_text_size=10,
max_text_size=512,
**kwargs):
self.batch_size = batch_size
self.data_format = data_format
self.tcl_len = tcl_len
self.min_crop_size = min_crop_size
self.min_text_size = min_text_size
self.max_text_size = max_text_size
self.Lexicon_Table = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
self.img_id = 0
def quad_area(self, poly):
"""
compute area of a polygon
:param poly:
:return:
"""
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
return np.sum(edge) / 2.
def gen_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
if True:
rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
:param polys:
:param tags:
:return:
"""
(h, w) = xxx_todo_changeme
if polys.shape[0] == 0:
return polys, np.array([]), np.array([])
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
validated_polys = []
validated_tags = []
hv_tags = []
for poly, tag in zip(polys, tags):
quad = self.gen_quad_from_poly(poly)
p_area = self.quad_area(quad)
if abs(p_area) < 1:
print('invalid poly')
continue
if p_area > 0:
if tag == False:
print('poly in wrong direction')
tag = True # reversed cases should be ignore
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
1), :]
quad = quad[(0, 3, 2, 1), :]
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
quad[2])
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
quad[2])
hv_tag = 1
if len_w * 2.0 < len_h:
hv_tag = 0
validated_polys.append(poly)
validated_tags.append(tag)
hv_tags.append(hv_tag)
return np.array(validated_polys), np.array(validated_tags), np.array(
hv_tags)
def crop_area(self,
im,
polys,
tags,
hv_tags,
txts,
crop_background=False,
max_tries=25):
"""
make random crop from the input image
:param im:
:param polys: [b,4,2]
:param tags:
:param crop_background:
:param max_tries: 50 -> 25
:return:
"""
h, w, _ = im.shape
pad_h = h // 10
pad_w = w // 10
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
for poly in polys:
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h:maxy + pad_h] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return im, polys, tags, hv_tags, txts
for i in range(max_tries):
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
if xmax - xmin < self.min_crop_size or \
ymax - ymin < self.min_crop_size:
continue
if polys.shape[0] != 0:
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
selected_polys = np.where(
np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
txts_tmp = []
for selected_poly in selected_polys:
txts_tmp.append(txts[selected_poly])
txts = txts_tmp
# print(1111)
return im[ymin: ymax + 1, xmin: xmax + 1, :], \
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
else:
continue
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
hv_tags = hv_tags[selected_polys]
txts_tmp = []
for selected_poly in selected_polys:
txts_tmp.append(txts[selected_poly])
txts = txts_tmp
polys[:, :, 0] -= xmin
polys[:, :, 1] -= ymin
return im, polys, tags, hv_tags, txts
return im, polys, tags, hv_tags, txts
def fit_and_gather_tcl_points_v2(self,
min_area_quad,
poly,
max_h,
max_w,
fixed_point_num=64,
img_id=0,
reference_height=3):
"""
Find the center point of poly as key_points, then fit and gather.
"""
key_point_xys = []
point_num = poly.shape[0]
for idx in range(point_num // 2):
center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
key_point_xys.append(center_point)
tmp_image = np.zeros(
shape=(
max_h,
max_w, ), dtype='float32')
cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
False, 1.0)
ys, xs = np.where(tmp_image > 0)
xy_text = np.array(list(zip(xs, ys)), dtype='float32')
# left_center_pt = np.array(key_point_xys[0]).reshape(1, 2)
# right_center_pt = np.array(key_point_xys[-1]).reshape(1, 2)
left_center_pt = (
(min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
right_center_pt = (
(min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / (
np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_unit_vec_tile = np.tile(proj_unit_vec,
(xy_text.shape[0], 1)) # (n, 2)
left_center_pt_tile = np.tile(left_center_pt,
(xy_text.shape[0], 1)) # (n, 2)
xy_text_to_left_center = xy_text - left_center_pt_tile
proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
xy_text = xy_text[np.argsort(proj_value)]
# convert to np and keep the num of point not greater then fixed_point_num
pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
point_num = len(pos_info)
if point_num > fixed_point_num:
keep_ids = [
int((point_num * 1.0 / fixed_point_num) * x)
for x in range(fixed_point_num)
]
pos_info = pos_info[keep_ids, :]
keep = int(min(len(pos_info), fixed_point_num))
if np.random.rand() < 0.2 and reference_height >= 3:
dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
[keep, 1])
pos_info += random_float
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
# padding to fixed length
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
pos_m[:keep] = 1.0
return pos_l, pos_m
def generate_direction_map(self, poly_quads, n_char, direction_map):
"""
"""
width_list = []
height_list = []
for quad in poly_quads:
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3])) / 2.0
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
width_list.append(quad_w)
height_list.append(quad_h)
norm_width = max(sum(width_list) / n_char, 1.0)
average_height = max(sum(height_list) / len(height_list), 1.0)
for quad in poly_quads:
direct_vector_full = (
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
direct_vector = direct_vector_full / (
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
direction_label = tuple(
map(float,
[direct_vector[0], direct_vector[1], 1.0 / average_height]))
cv2.fillPoly(direction_map,
quad.round().astype(np.int32)[np.newaxis, :, :],
direction_label)
return direction_map
def calculate_average_height(self, poly_quads):
"""
"""
height_list = []
for quad in poly_quads:
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
height_list.append(quad_h)
average_height = max(sum(height_list) / len(height_list), 1.0)
return average_height
def encode(self, text):
text_list = []
for char in text:
if char not in self.dict:
continue
text_list.append([self.dict[char]])
if len(text_list) == 0:
return None
return text_list
def generate_tcl_ctc_label(self,
h,
w,
polys,
tags,
text_strs,
ds_ratio,
tcl_ratio=0.3,
shrink_ratio_of_width=0.15):
"""
Generate polygon.
"""
score_map_big = np.zeros(
(
h,
w, ), dtype=np.float32)
h, w = int(h * ds_ratio), int(w * ds_ratio)
polys = polys * ds_ratio
score_map = np.zeros(
(
h,
w, ), dtype=np.float32)
score_label_map = np.zeros(
(
h,
w, ), dtype=np.float32)
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
training_mask = np.ones(
(
h,
w, ), dtype=np.float32)
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
[1, 1, 3]).astype(np.float32)
label_idx = 0
score_label_map_text_label_list = []
pos_list, pos_mask, label_list = [], [], []
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0]
tag = poly_tag[1]
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
min_area_quad_w = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
continue
if tag:
# continue
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
else:
text_label = text_strs[poly_idx]
text_label = self.prepare_text_label(text_label,
self.Lexicon_Table)
# text = text.decode('utf-8')
# text_label_index_list = self.encode(text)
text_label_index_list = [[self.Lexicon_Table.index(c_)]
for c_ in text_label
if c_ in self.Lexicon_Table]
if len(text_label_index_list) < 1:
continue
tcl_poly = self.poly2tcl(poly, tcl_ratio)
tcl_quads = self.poly2quads(tcl_poly)
poly_quads = self.poly2quads(poly)
# stcl map
stcl_quads, quad_index = self.shrink_poly_along_width(
tcl_quads,
shrink_ratio_of_width=shrink_ratio_of_width,
expand_height_ratio=1.0 / tcl_ratio)
# generate tcl map
cv2.fillPoly(score_map,
np.round(stcl_quads).astype(np.int32), 1.0)
cv2.fillPoly(score_map_big,
np.round(stcl_quads / ds_ratio).astype(np.int32),
1.0)
# generate tbo map
# tbo_tcl_poly = poly2tcl(poly, 0.5)
# tbo_tcl_quads = poly2quads(tbo_tcl_poly)
# for idx, quad in enumerate(tbo_tcl_quads):
for idx, quad in enumerate(stcl_quads):
quad_mask = np.zeros((h, w), dtype=np.float32)
quad_mask = cv2.fillPoly(
quad_mask,
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
quad_mask, tbo_map)
# score label map and score_label_map_text_label_list for refine
if label_idx == 0:
text_pos_list_ = [[len(self.Lexicon_Table)], ]
score_label_map_text_label_list.append(text_pos_list_)
label_idx += 1
# cv2.fillPoly(score_label_map, np.round(poly_quads[np.newaxis, :, :]).astype(np.int32), label_idx)
cv2.fillPoly(score_label_map,
np.round(poly_quads).astype(np.int32), label_idx)
score_label_map_text_label_list.append(text_label_index_list)
# direction info, fix-me
n_char = len(text_label_index_list)
direction_map = self.generate_direction_map(poly_quads, n_char,
direction_map)
# pos info
average_shrink_height = self.calculate_average_height(
stcl_quads)
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
min_area_quad,
poly,
max_h=h,
max_w=w,
fixed_point_num=64,
img_id=self.img_id,
reference_height=average_shrink_height)
label_l = text_label_index_list
if len(text_label_index_list) < 2:
continue
pos_list.append(pos_l)
pos_mask.append(pos_m)
label_list.append(label_l)
# use big score_map for smooth tcl lines
score_map_big_resized = cv2.resize(
score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
return score_map, score_label_map, tbo_map, direction_map, training_mask, \
pos_list, pos_mask, label_list, score_label_map_text_label_list
def adjust_point(self, poly):
"""
adjust point order.
"""
point_num = poly.shape[0]
if point_num == 4:
len_1 = np.linalg.norm(poly[0] - poly[1])
len_2 = np.linalg.norm(poly[1] - poly[2])
len_3 = np.linalg.norm(poly[2] - poly[3])
len_4 = np.linalg.norm(poly[3] - poly[0])
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
poly = poly[[1, 2, 3, 0], :]
elif point_num > 4:
vector_1 = poly[0] - poly[1]
vector_2 = poly[1] - poly[2]
cos_theta = np.dot(vector_1, vector_2) / (
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
theta = np.arccos(np.round(cos_theta, decimals=4))
if abs(theta) > (70 / 180 * math.pi):
index = list(range(1, point_num)) + [0]
poly = poly[np.array(index), :]
return poly
def gen_min_area_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
if point_num == 4:
min_area_quad = poly
center_point = np.sum(poly, axis=0) / 4
else:
rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad, center_point
def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def shrink_poly_along_width(self,
quads,
shrink_ratio_of_width,
expand_height_ratio=1.0):
"""
shrink poly with given length.
"""
upper_edge_list = []
def get_cut_info(edge_len_list, cut_len):
for idx, edge_len in enumerate(edge_len_list):
cut_len -= edge_len
if cut_len <= 0.000001:
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
return idx, ratio
for quad in quads:
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
upper_edge_list.append(upper_edge_len)
# length of left edge and right edge.
left_length = np.linalg.norm(quads[0][0] - quads[0][
3]) * expand_height_ratio
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
2]) * expand_height_ratio
shrink_length = min(left_length, right_length,
sum(upper_edge_list)) * shrink_ratio_of_width
# shrinking length
upper_len_left = shrink_length
upper_len_right = sum(upper_edge_list) - shrink_length
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
left_quad = self.shrink_quad_along_width(
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
right_quad = self.shrink_quad_along_width(
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
out_quad_list = []
if left_idx == right_idx:
out_quad_list.append(
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
else:
out_quad_list.append(left_quad)
for idx in range(left_idx + 1, right_idx):
out_quad_list.append(quads[idx])
out_quad_list.append(right_quad)
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
def prepare_text_label(self, label_str, Lexicon_Table):
"""
Prepare text lablel by given Lexicon_Table.
"""
if len(Lexicon_Table) == 36:
return label_str.upper()
else:
return label_str
def vector_angle(self, A, B):
"""
Calculate the angle between vector AB and x-axis positive direction.
"""
AB = np.array([B[1] - A[1], B[0] - A[0]])
return np.arctan2(*AB)
def theta_line_cross_point(self, theta, point):
"""
Calculate the line through given point and angle in ax + by + c =0 form.
"""
x, y = point
cos = np.cos(theta)
sin = np.sin(theta)
return [sin, -cos, cos * y - sin * x]
def line_cross_two_point(self, A, B):
"""
Calculate the line through given point A and B in ax + by + c =0 form.
"""
angle = self.vector_angle(A, B)
return self.theta_line_cross_point(angle, A)
def average_angle(self, poly):
"""
Calculate the average angle between left and right edge in given poly.
"""
p0, p1, p2, p3 = poly
angle30 = self.vector_angle(p3, p0)
angle21 = self.vector_angle(p2, p1)
return (angle30 + angle21) / 2
def line_cross_point(self, line1, line2):
"""
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
"""
a1, b1, c1 = line1
a2, b2, c2 = line2
d = a1 * b2 - a2 * b1
if d == 0:
# print("line1", line1)
# print("line2", line2)
print('Cross point does not exist')
return np.array([0, 0], dtype=np.float32)
else:
x = (b1 * c2 - b2 * c1) / d
y = (a2 * c1 - a1 * c2) / d
return np.array([x, y], dtype=np.float32)
def quad2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point. (4, 2)
"""
ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
def poly2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point.
"""
ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
tcl_poly = np.zeros_like(poly)
point_num = poly.shape[0]
for idx in range(point_num // 2):
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
) * ratio_pair
tcl_poly[idx] = point_pair[0]
tcl_poly[point_num - 1 - idx] = point_pair[1]
return tcl_poly
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
"""
Generate tbo_map for give quad.
"""
# upper and lower line function: ax + by + c = 0;
up_line = self.line_cross_two_point(quad[0], quad[1])
lower_line = self.line_cross_two_point(quad[3], quad[2])
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[1] - quad[2]))
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3]))
# average angle of left and right line.
angle = self.average_angle(quad)
xy_in_poly = np.argwhere(tcl_mask == 1)
for y, x in xy_in_poly:
point = (x, y)
line = self.theta_line_cross_point(angle, point)
cross_point_upper = self.line_cross_point(up_line, line)
cross_point_lower = self.line_cross_point(lower_line, line)
##FIX, offset reverse
upper_offset_x, upper_offset_y = cross_point_upper - point
lower_offset_x, lower_offset_y = cross_point_lower - point
tbo_map[y, x, 0] = upper_offset_y
tbo_map[y, x, 1] = upper_offset_x
tbo_map[y, x, 2] = lower_offset_y
tbo_map[y, x, 3] = lower_offset_x
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
return tbo_map
def poly2quads(self, poly):
"""
Split poly into quads.
"""
quad_list = []
point_num = poly.shape[0]
# point pair
point_pair_list = []
for idx in range(point_num // 2):
point_pair = [poly[idx], poly[point_num - 1 - idx]]
point_pair_list.append(point_pair)
quad_num = point_num // 2 - 1
for idx in range(quad_num):
# reshape and adjust to clock-wise
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
).reshape(4, 2)[[0, 2, 3, 1]])
return np.array(quad_list)
def rotate_im_poly(self, im, text_polys):
"""
rotate image with 90 / 180 / 270 degre
"""
im_w, im_h = im.shape[1], im.shape[0]
dst_im = im.copy()
dst_polys = []
rand_degree_ratio = np.random.rand()
rand_degree_cnt = 1
if rand_degree_ratio > 0.5:
rand_degree_cnt = 3
for i in range(rand_degree_cnt):
dst_im = np.rot90(dst_im)
rot_degree = -90 * rand_degree_cnt
rot_angle = rot_degree * math.pi / 180.0
n_poly = text_polys.shape[0]
cx, cy = 0.5 * im_w, 0.5 * im_h
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
for i in range(n_poly):
wordBB = text_polys[i]
poly = []
for j in range(4): # 16->4
sx, sy = wordBB[j][0], wordBB[j][1]
dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
sy - cy) + ncx
dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
sy - cy) + ncy
poly.append([dx, dy])
dst_polys.append(poly)
return dst_im, np.array(dst_polys, dtype=np.float32)
def __call__(self, data):
input_size = 512
im = data['image']
text_polys = data['polys']
text_tags = data['tags']
text_strs = data['strs']
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
text_polys, text_tags, (h, w))
if text_polys.shape[0] <= 0:
return None
# set aspect ratio and keep area fix
asp_scales = np.arange(1.0, 1.55, 0.1)
asp_scale = np.random.choice(asp_scales)
if np.random.rand() < 0.5:
asp_scale = 1.0 / asp_scale
asp_scale = math.sqrt(asp_scale)
asp_wx = asp_scale
asp_hy = 1.0 / asp_scale
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
text_polys[:, :, 0] *= asp_wx
text_polys[:, :, 1] *= asp_hy
h, w, _ = im.shape
if max(h, w) > 2048:
rd_scale = 2048.0 / max(h, w)
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
text_polys *= rd_scale
h, w, _ = im.shape
if min(h, w) < 16:
return None
# no background
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
im,
text_polys,
text_tags,
hv_tags,
text_strs,
crop_background=False)
if text_polys.shape[0] == 0:
return None
# # continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
new_h, new_w, _ = im.shape
if (new_h is None) or (new_w is None):
return None
# resize image
std_ratio = float(input_size) / max(new_w, new_h)
rand_scales = np.array(
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
rz_scale = std_ratio * np.random.choice(rand_scales)
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
text_polys[:, :, 0] *= rz_scale
text_polys[:, :, 1] *= rz_scale
# add gaussian blur
if np.random.rand() < 0.1 * 0.5:
ks = np.random.permutation(5)[0] + 1
ks = int(ks / 2) * 2 + 1
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
# add brighter
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 + np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
# add darker
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 - np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
# Padding the im to [input_size, input_size]
new_h, new_w, _ = im.shape
if min(new_w, new_h) < input_size * 0.5:
return None
im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
im_padded[:, :, 2] = 0.485 * 255
im_padded[:, :, 1] = 0.456 * 255
im_padded[:, :, 0] = 0.406 * 255
# Random the start position
del_h = input_size - new_h
del_w = input_size - new_w
sh, sw = 0, 0
if del_h > 1:
sh = int(np.random.rand() * del_h)
if del_w > 1:
sw = int(np.random.rand() * del_w)
# Padding
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
text_polys[:, :, 0] += sw
text_polys[:, :, 1] += sh
score_map, score_label_map, border_map, direction_map, training_mask, \
pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
input_size,
text_polys,
text_tags,
text_strs, 0.25)
if len(label_list) <= 0: # eliminate negative samples
return None
pos_list_temp = np.zeros([64, 3])
pos_mask_temp = np.zeros([64, 1])
label_list_temp = np.zeros([50, 1]) + 36
for i, label in enumerate(label_list):
n = len(label)
if n > 50:
label_list[i] = label[:50]
continue
while n < 50:
label.append([36])
n += 1
for i in range(len(label_list)):
label_list[i] = np.array(label_list[i])
if len(pos_list) <= 0 or len(pos_list) > 30:
return None
for __ in range(30 - len(pos_list), 0, -1):
pos_list.append(pos_list_temp)
pos_mask.append(pos_mask_temp)
label_list.append(label_list_temp)
if self.img_id == self.batch_size - 1:
self.img_id = 0
else:
self.img_id += 1
im_padded[:, :, 2] -= 0.485 * 255
im_padded[:, :, 1] -= 0.456 * 255
im_padded[:, :, 0] -= 0.406 * 255
im_padded[:, :, 2] /= (255.0 * 0.229)
im_padded[:, :, 1] /= (255.0 * 0.224)
im_padded[:, :, 0] /= (255.0 * 0.225)
im_padded = im_padded.transpose((2, 0, 1))
images = im_padded[::-1, :, :]
tcl_maps = score_map[np.newaxis, :, :]
tcl_label_maps = score_label_map[np.newaxis, :, :]
border_maps = border_map.transpose((2, 0, 1))
direction_maps = direction_map.transpose((2, 0, 1))
training_masks = training_mask[np.newaxis, :, :]
pos_list = np.array(pos_list)
pos_mask = np.array(pos_mask)
label_list = np.array(label_list)
data['images'] = images
data['tcl_maps'] = tcl_maps
data['tcl_label_maps'] = tcl_label_maps
data['border_maps'] = border_maps
data['direction_maps'] = direction_maps
data['training_masks'] = training_masks
data['label_list'] = label_list
data['pos_list'] = pos_list
data['pos_mask'] = pos_mask
return data
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
from .imaug import transform, create_operators
import random
class PGDateSet(Dataset):
def __init__(self, config, mode, logger):
super(PGDateSet, self).__init__()
self.logger = logger
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
self.data_format = dataset_config.get('data_format', 'icdar')
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
# self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
self.data_format)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def shuffle_data_random(self):
if self.do_shuffle:
random.shuffle(self.data_lines)
return
def extract_polys(self, poly_txt_path):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys, txt_tags, txts = [], [], []
with open(poly_txt_path) as f:
for line in f.readlines():
poly_str, txt = line.strip().split('\t')
poly = map(float, poly_str.split(','))
text_polys.append(
np.array(
list(poly), dtype=np.float32).reshape(-1, 2))
txts.append(txt)
if txt == '###':
txt_tags.append(True)
else:
txt_tags.append(False)
return np.array(list(map(np.array, text_polys))), \
np.array(txt_tags, dtype=np.bool), txts
def extract_info_textnet(self, im_fn, img_dir=''):
"""
Extract information from line in textnet format.
"""
info_list = im_fn.split('\t')
img_path = ''
for ext in ['.jpg', '.png', '.jpeg', '.JPG']:
if os.path.exists(os.path.join(img_dir, info_list[0] + ext)):
img_path = os.path.join(img_dir, info_list[0] + ext)
break
if img_path == '':
print('Image {0} NOT found in {1}, and it will be ignored.'.format(
info_list[0], img_dir))
nBox = (len(info_list) - 1) // 9
wordBBs, txts, txt_tags = [], [], []
for n in range(0, nBox):
wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
txt = info_list[(n + 1) * 9]
wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
[wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
txts.append(txt)
if txt == '###':
txt_tags.append(True)
else:
txt_tags.append(False)
return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, data_source in enumerate(file_list):
image_files = []
if data_format == 'icdar':
image_files = [
(data_source, x)
for x in os.listdir(os.path.join(data_source, 'rgb'))
if x.split('.')[-1] in ['jpg', 'png', 'jpeg', 'JPG']
]
elif data_format == 'textnet':
with open(data_source) as f:
image_files = [(data_source, x.strip())
for x in f.readlines()]
else:
print("Unrecognized data format...")
exit(-1)
image_files = random.sample(
image_files, round(len(image_files) * ratio_list[idx]))
data_lines.extend(image_files)
return data_lines
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_path, data_line = self.data_lines[file_idx]
try:
if self.data_format == 'icdar':
im_path = os.path.join(data_path, 'rgb', data_line)
poly_path = os.path.join(data_path, 'poly',
data_line.split('.')[0] + '.txt')
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
else:
image_dir = os.path.join(os.path.dirname(data_path), 'image')
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
data_line, image_dir)
data = {
'img_path': im_path,
'polys': text_polys,
'tags': text_tags,
'strs': text_strs
}
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
self.data_idx_order_list[idx], e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
......@@ -29,10 +29,11 @@ def build_loss(config):
# cls loss
from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss'
]
'SRNLoss', 'PGLoss']
config = copy.deepcopy(config)
module_name = config.pop('name')
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
import paddle
import numpy as np
import copy
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
class PGLoss(nn.Layer):
"""
Differentiable Binarization (DB) Loss Function
args:
param (dict): the super paramter for DB Loss
"""
def __init__(self, alpha=5, beta=10, eps=1e-6, **kwargs):
super(PGLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.dice_loss = DiceLoss(eps=eps)
def org_tcl_rois(self, batch_size, pos_lists, pos_masks, label_lists):
"""
"""
pos_lists_, pos_masks_, label_lists_ = [], [], []
img_bs = batch_size
tcl_bs = 64
ngpu = int(batch_size / img_bs)
img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
pos_lists_split, pos_masks_split, label_lists_split = [], [], []
for i in range(ngpu):
pos_lists_split.append([])
pos_masks_split.append([])
label_lists_split.append([])
for i in range(img_ids.shape[0]):
img_id = img_ids[i]
gpu_id = int(img_id / img_bs)
img_id = img_id % img_bs
pos_list = pos_lists[i].copy()
pos_list[:, 0] = img_id
pos_lists_split[gpu_id].append(pos_list)
pos_masks_split[gpu_id].append(pos_masks[i].copy())
label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
# repeat or delete
for i in range(ngpu):
vp_len = len(pos_lists_split[i])
if vp_len <= tcl_bs:
for j in range(0, tcl_bs - vp_len):
pos_list = pos_lists_split[i][j].copy()
pos_lists_split[i].append(pos_list)
pos_mask = pos_masks_split[i][j].copy()
pos_masks_split[i].append(pos_mask)
label_list = copy.deepcopy(label_lists_split[i][j])
label_lists_split[i].append(label_list)
else:
for j in range(0, vp_len - tcl_bs):
c_len = len(pos_lists_split[i])
pop_id = np.random.permutation(c_len)[0]
pos_lists_split[i].pop(pop_id)
pos_masks_split[i].pop(pop_id)
label_lists_split[i].pop(pop_id)
# merge
for i in range(ngpu):
pos_lists_.extend(pos_lists_split[i])
pos_masks_.extend(pos_masks_split[i])
label_lists_.extend(label_lists_split[i])
return pos_lists_, pos_masks_, label_lists_
def pre_process(self, label_list, pos_list, pos_mask):
label_list = label_list.numpy()
b, h, w, c = label_list.shape
pos_list = pos_list.numpy()
pos_mask = pos_mask.numpy()
pos_list_t = []
pos_mask_t = []
label_list_t = []
for i in range(b):
for j in range(30):
if pos_mask[i, j].any():
pos_list_t.append(pos_list[i][j])
pos_mask_t.append(pos_mask[i][j])
label_list_t.append(label_list[i][j])
pos_list, pos_mask, label_list = self.org_tcl_rois(
b, pos_list_t, pos_mask_t, label_list_t)
label = []
tt = [l.tolist() for l in label_list]
for i in range(64):
k = 0
for j in range(50):
if tt[i][j][0] != 36:
k += 1
else:
break
label.append(k)
label = paddle.to_tensor(label)
label = paddle.cast(label, dtype='int64')
pos_list = paddle.to_tensor(pos_list)
pos_mask = paddle.to_tensor(pos_mask)
label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
label_list = paddle.cast(label_list, dtype='int32')
return pos_list, pos_mask, label_list, label
def border_loss(self, f_border, l_border, l_score, l_mask):
l_border_split, l_border_norm = paddle.tensor.split(
l_border, num_or_sections=[4, 1], axis=1)
f_border_split = f_border
b, c, h, w = l_border_norm.shape
l_border_norm_split = paddle.expand(
x=l_border_norm, shape=[b, 4 * c, h, w])
b, c, h, w = l_score.shape
l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
b, c, h, w = l_mask.shape
l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
border_diff = l_border_split - f_border_split
abs_border_diff = paddle.abs(border_diff)
border_sign = abs_border_diff < 1.0
border_sign = paddle.cast(border_sign, dtype='float32')
border_sign.stop_gradient = True
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
(abs_border_diff - 0.5) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
return border_loss
def direction_loss(self, f_direction, l_direction, l_score, l_mask):
l_direction_split, l_direction_norm = paddle.tensor.split(
l_direction, num_or_sections=[2, 1], axis=1)
f_direction_split = f_direction
b, c, h, w = l_direction_norm.shape
l_direction_norm_split = paddle.expand(
x=l_direction_norm, shape=[b, 2 * c, h, w])
b, c, h, w = l_score.shape
l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
b, c, h, w = l_mask.shape
l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
direction_diff = l_direction_split - f_direction_split
abs_direction_diff = paddle.abs(direction_diff)
direction_sign = abs_direction_diff < 1.0
direction_sign = paddle.cast(direction_sign, dtype='float32')
direction_sign.stop_gradient = True
direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
(abs_direction_diff - 0.5) * (1.0 - direction_sign)
direction_out_loss = l_direction_norm_split * direction_in_loss
direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
(paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
return direction_loss
def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
f_char = paddle.transpose(f_char, [0, 2, 3, 1])
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
tcl_pos = paddle.cast(tcl_pos, dtype=int)
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
f_tcl_char = paddle.reshape(f_tcl_char,
[-1, 64, 37]) # len(Lexicon_Table)+1
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
b, c, l = tcl_mask.shape
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
tcl_mask_fg.stop_gradient = True
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
-20.0)
f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
N, B, _ = f_tcl_char_ld.shape
input_lengths = paddle.to_tensor([N] * B, dtype='int64')
cost = paddle.nn.functional.ctc_loss(
log_probs=f_tcl_char_ld,
labels=tcl_label,
input_lengths=input_lengths,
label_lengths=label_t,
blank=36,
reduction='none')
cost = cost.mean()
return cost
def forward(self, predicts, labels):
images, tcl_maps, tcl_label_maps, border_maps \
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
# for all the batch_size
pos_list, pos_mask, label_list, label_t = self.pre_process(
label_list, pos_list, pos_mask)
f_score, f_boder, f_direction, f_char = predicts
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
border_loss = self.border_loss(f_boder, border_maps, tcl_maps,
training_masks)
direction_loss = self.direction_loss(f_direction, direction_maps,
tcl_maps, training_masks)
ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
losses = {
'loss': loss_all,
"score_loss": score_loss,
"border_loss": border_loss,
"direction_loss": direction_loss,
"ctc_loss": ctc_loss
}
return losses
......@@ -26,8 +26,9 @@ def build_metric(config):
from .det_metric import DetMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']
config = copy.deepcopy(config)
module_name = config.pop('name')
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__all__ = ['E2EMetric']
from ppocr.utils.e2e_metric.Deteval import *
class E2EMetric(object):
def __init__(self, main_indicator='f_score_e2e', **kwargs):
self.label_list = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
'''
batch: a list produced by dataloaders.
image: np.ndarray of shape (N, C, H, W).
ratio_list: np.ndarray of shape(N,2)
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
preds: a list of dict produced by post process
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
'''
gt_polyons_batch = batch[2]
temp_gt_strs_batch = batch[3]
ignore_tags_batch = batch[4]
gt_strs_batch = []
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
for temp_list in temp_gt_strs_batch:
t = ""
for index in temp_list:
if index < 36:
t += self.label_list[index]
gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip(
preds, gt_polyons_batch, gt_strs_batch, ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': gt_str,
'ignore': ignore_tag
} for gt_polyon, gt_str, ignore_tag in
zip(gt_polyons, gt_strs, ignore_tags)]
# prepare det
e2e_info_list = [{
'points': det_polyon,
'text': pred_str
} for det_polyon, pred_str in zip(pred['points'], preds['strs'])]
result = get_socre(gt_info_list, e2e_info_list)
self.results.append(result)
def get_metric(self):
"""
return metrics {
'precision': 0,
'recall': 0,
'hmean': 0
}
"""
metircs = combine_results(self.results)
self.reset()
return metircs
def reset(self):
self.results = [] # clear results
......@@ -150,7 +150,7 @@ class DetectionIoUEvaluator(object):
pairs.append({'gt': gtNum, 'det': detNum})
detMatchedNums.append(detNum)
evaluationLog += "Match GT #" + \
str(gtNum) + " with Det #" + str(detNum) + "\n"
str(gtNum) + " with Det #" + str(detNum) + "\n"
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
numDetCare = (len(detPols) - len(detDontCarePolsNum))
......@@ -162,7 +162,7 @@ class DetectionIoUEvaluator(object):
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
hmean = 0 if (precision + recall) == 0 else 2.0 * \
precision * recall / (precision + recall)
precision * recall / (precision + recall)
matchedSum += detMatched
numGlobalCareGt += numGtCare
......@@ -200,7 +200,8 @@ class DetectionIoUEvaluator(object):
methodPrecision = 0 if numGlobalCareDet == 0 else float(
matchedSum) / numGlobalCareDet
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
methodRecall * methodPrecision / (methodRecall + methodPrecision)
methodRecall * methodPrecision / (
methodRecall + methodPrecision)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics = {
......
......@@ -26,6 +26,9 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
else:
raise NotImplementedError
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ["ResNet"]
class ConvBNLayer(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs):
# if self.is_vd_mode:
# inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=stride,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
# depth = [3, 4, 6, 3]
depth = [3, 4, 6, 3, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512, 1024,
2048] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=7,
stride=2,
act='relu',
name="conv1_1")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
self.out_channels = [3, 64]
# num_filters = [64, 128, 256, 512, 512]
if layers >= 50:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
out = [inputs]
y = self.conv1_1(inputs)
out.append(y)
y = self.pool2d_max(y)
for block in self.stages:
y = block(y)
out.append(y)
return out
......@@ -20,6 +20,7 @@ def build_head(config):
from .det_db_head import DBHead
from .det_east_head import EASTHead
from .det_sast_head import SASTHead
from .e2e_pg_head import PGHead
# rec head
from .rec_ctc_head import CTCHead
......@@ -30,8 +31,8 @@ def build_head(config):
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead'
]
'SRNHead', 'PGHead']
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance",
use_global_stats=False)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class PGHead(nn.Layer):
"""
"""
def __init__(self, in_channels, model_name, **kwargs):
super(PGHead, self).__init__()
self.model_name = model_name
self.conv_f_score1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_score{}".format(1))
self.conv_f_score2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_score{}".format(2))
self.conv_f_score3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_score{}".format(3))
self.conv1 = nn.Conv2D(
in_channels=128,
out_channels=1,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
bias_attr=False)
self.conv_f_boder1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_boder{}".format(1))
self.conv_f_boder2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_boder{}".format(2))
self.conv_f_boder3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_boder{}".format(3))
self.conv2 = nn.Conv2D(
in_channels=128,
out_channels=4,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
bias_attr=False)
self.conv_f_char1 = ConvBNLayer(
in_channels=in_channels,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(1))
self.conv_f_char2 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_char{}".format(2))
self.conv_f_char3 = ConvBNLayer(
in_channels=128,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(3))
self.conv_f_char4 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_char{}".format(4))
self.conv_f_char5 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(5))
self.conv3 = nn.Conv2D(
in_channels=256,
out_channels=6625,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
bias_attr=False)
self.conv_f_direc1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_direc{}".format(1))
self.conv_f_direc2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_direc{}".format(2))
self.conv_f_direc3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_direc{}".format(3))
self.conv4 = nn.Conv2D(
in_channels=128,
out_channels=2,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
bias_attr=False)
def forward(self, x):
f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score)
f_score = self.conv1(f_score)
f_score = F.sigmoid(f_score)
# f_boder
f_boder = self.conv_f_boder1(x)
f_boder = self.conv_f_boder2(f_boder)
f_boder = self.conv_f_boder3(f_boder)
f_boder = self.conv2(f_boder)
f_char = self.conv_f_char1(x)
f_char = self.conv_f_char2(f_char)
f_char = self.conv_f_char3(f_char)
f_char = self.conv_f_char4(f_char)
f_char = self.conv_f_char5(f_char)
f_char = self.conv3(f_char)
f_direction = self.conv_f_direc1(x)
f_direction = self.conv_f_direc2(f_direction)
f_direction = self.conv_f_direc3(f_direction)
f_direction = self.conv4(f_direction)
return f_score, f_boder, f_direction, f_char
......@@ -14,12 +14,14 @@
__all__ = ['build_neck']
def build_neck(config):
from .db_fpn import DBFPN
from .east_fpn import EASTFPN
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
from .pg_fpn import PGFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
use_global_stats=False)
def forward(self, inputs):
# if self.is_vd_mode:
# inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class DeConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
groups=1,
if_act=True,
act=None,
name=None):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.deconv = nn.Conv2DTranspose(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance",
use_global_stats=False)
def forward(self, x):
x = self.deconv(x)
x = self.bn(x)
return x
class FPN_Up_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Up_Fusion, self).__init__()
in_channels = in_channels[::-1]
out_channels = [256, 256, 192, 192, 128]
self.h0_conv = ConvBNLayer(
in_channels[0], out_channels[0], 1, 1, act=None, name='conv_h0')
self.h1_conv = ConvBNLayer(
in_channels[1], out_channels[1], 1, 1, act=None, name='conv_h1')
self.h2_conv = ConvBNLayer(
in_channels[2], out_channels[2], 1, 1, act=None, name='conv_h2')
self.h3_conv = ConvBNLayer(
in_channels[3], out_channels[3], 1, 1, act=None, name='conv_h3')
self.h4_conv = ConvBNLayer(
in_channels[4], out_channels[4], 1, 1, act=None, name='conv_h4')
self.dconv0 = DeConvBNLayer(
in_channels=out_channels[0],
out_channels=out_channels[1],
name="dconv_{}".format(0))
self.dconv1 = DeConvBNLayer(
in_channels=out_channels[1],
out_channels=out_channels[2],
act=None,
name="dconv_{}".format(1))
self.dconv2 = DeConvBNLayer(
in_channels=out_channels[2],
out_channels=out_channels[3],
act=None,
name="dconv_{}".format(2))
self.dconv3 = DeConvBNLayer(
in_channels=out_channels[3],
out_channels=out_channels[4],
act=None,
name="dconv_{}".format(3))
self.conv_g1 = ConvBNLayer(
in_channels=out_channels[1],
out_channels=out_channels[1],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(1))
self.conv_g2 = ConvBNLayer(
in_channels=out_channels[2],
out_channels=out_channels[2],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(2))
self.conv_g3 = ConvBNLayer(
in_channels=out_channels[3],
out_channels=out_channels[3],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(3))
self.conv_g4 = ConvBNLayer(
in_channels=out_channels[4],
out_channels=out_channels[4],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(4))
self.convf = ConvBNLayer(
in_channels=out_channels[4],
out_channels=out_channels[4],
kernel_size=1,
stride=1,
act=None,
name="conv_f{}".format(4))
def _add_relu(self, x1, x2):
x = paddle.add(x=x1, y=x2)
x = F.relu(x)
return x
def forward(self, x):
f = x[2:][::-1]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
h3 = self.h3_conv(f[3])
h4 = self.h4_conv(f[4])
g0 = self.dconv0(h0)
g1 = self.dconv2(self.conv_g2(self._add_relu(g0, h1)))
g2 = self.dconv2(self.conv_g2(self._add_relu(g1, h2)))
g3 = self.dconv3(self.conv_g2(self._add_relu(g2, h3)))
g4 = self.dconv4(self.conv_g2(self._add_relu(g3, h4)))
return g4
class FPN_Down_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Down_Fusion, self).__init__()
out_channels = [32, 64, 128]
self.h0_conv = ConvBNLayer(
in_channels[0], out_channels[0], 3, 1, act=None, name='FPN_d1')
self.h1_conv = ConvBNLayer(
in_channels[1], out_channels[1], 3, 1, act=None, name='FPN_d2')
self.h2_conv = ConvBNLayer(
in_channels[2], out_channels[2], 3, 1, act=None, name='FPN_d3')
self.g0_conv = ConvBNLayer(
out_channels[0], out_channels[1], 3, 2, act=None, name='FPN_d4')
self.g1_conv = nn.Sequential(
ConvBNLayer(
out_channels[1],
out_channels[1],
3,
1,
act='relu',
name='FPN_d5'),
ConvBNLayer(
out_channels[1], out_channels[2], 3, 2, act=None,
name='FPN_d6'))
self.g2_conv = nn.Sequential(
ConvBNLayer(
out_channels[2],
out_channels[2],
3,
1,
act='relu',
name='FPN_d7'),
ConvBNLayer(
out_channels[2], out_channels[2], 1, 1, act=None,
name='FPN_d8'))
def forward(self, x):
f = x[:3]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
g0 = self.g0_conv(h0)
g1 = paddle.add(x=g0, y=h1)
g1 = F.relu(g1)
g1 = self.g1_conv(g1)
g2 = paddle.add(x=g1, y=h2)
g2 = F.relu(g2)
g2 = self.g2_conv(g2)
return g2
class PGFPN(nn.Layer):
def __init__(self, in_channels, with_cab=False, **kwargs):
super(PGFPN, self).__init__()
self.in_channels = in_channels
self.with_cab = with_cab
self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels)
self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels)
self.out_channels = 128
def forward(self, x):
# down fpn
f_down = self.FPN_Down_Fusion(x)
# up fpn
f_up = self.FPN_Up_Fusion(x)
# fusion
f_common = paddle.add(x=f_down, y=f_up)
f_common = F.relu(f_common)
return f_common
......@@ -28,10 +28,11 @@ def build_post_process(config, global_config=None):
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
]
config = copy.deepcopy(config)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import numpy as np
from .locality_aware_nms import nms_locality
from ppocr.utils.e2e_utils.extract_textpoint import *
from ppocr.utils.e2e_utils.ski_thin import *
from ppocr.utils.e2e_utils.visual import *
import paddle
import cv2
import time
class PGPostProcess(object):
"""
The post process for SAST.
"""
def __init__(self,
score_thresh=0.5,
nms_thresh=0.2,
sample_pts_num=2,
shrink_ratio_of_width=0.3,
expand_scale=1.0,
tcl_map_thresh=0.5,
**kwargs):
self.result_path = ""
self.valid_set = 'totaltext'
self.Lexicon_Table = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.sample_pts_num = sample_pts_num
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
p_score, p_border, p_direction, p_char = outs_dict[:4]
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
src_h, src_w, ratio_h, ratio_w = shape_list[0]
if self.valid_set != 'totaltext':
is_curved = False
else:
is_curved = True
instance_yxs_list = generate_pivot_list(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
p_char = np.expand_dims(p_char, axis=0)
p_char = paddle.to_tensor(p_char)
char_seq_idx_set = []
for i in range(len(instance_yxs_list)):
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
featyre_seq = paddle.gather_nd(f_char_map, gather_info_lod)
featyre_seq = np.expand_dims(featyre_seq.numpy(), axis=0)
t = len(featyre_seq[0])
featyre_seq = paddle.to_tensor(featyre_seq)
l = np.array([[t]]).astype(np.int64)
length = paddle.to_tensor(l)
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
input=featyre_seq, blank=36, input_length=length)
seq_pred1 = seq_pred[0].numpy().tolist()[0]
seq_len = seq_pred[1].numpy()[0][0]
temp_t = []
for x in seq_pred1[:seq_len]:
temp_t.append(x)
char_seq_idx_set.append(temp_t)
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
print('the length of tcl point is less than 2, repeat')
yx_center_line.append(yx_center_line[-1])
# expand corresponding offset for total-text.
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# for visualization
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
# ndarry: (x, 2)
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
print('expand along width. {}'.format(detected_poly.shape))
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
print('--> too short, {}'.format(keep_str))
continue
keep_str_list.append(keep_str)
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
data = {
'points': poly_list,
'strs': keep_str_list,
}
# visualization
# if self.save_visualization:
# visualize_e2e_result(im_fn, poly_list, keep_str_list, src_im)
# visualize_point_result(im_fn, all_point_list, all_point_pair_list, src_im)
# save detected boxes
# txt_dir = (result_path[:-1] if result_path.endswith('/') else result_path) + '_txt_anno'
# if not os.path.exists(txt_dir):
# os.makedirs(txt_dir)
# res_file = os.path.join(txt_dir, '{}.txt'.format(im_prefix))
# with open(res_file, 'w') as f:
# for i_box, box in enumerate(poly_list):
# seq_str = keep_str_list[i_box]
# box = np.round(box).astype('int32')
# box_str = ','.join(str(s) for s in (box.flatten().tolist()))
# f.write('{}\t{}\r\n'.format(box_str, seq_str))
return data
......@@ -18,6 +18,7 @@ from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
......@@ -49,12 +50,12 @@ class SASTPostProcess(object):
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def point_pair2poly(self, point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
......@@ -66,31 +67,42 @@ class SASTPostProcess(object):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
......@@ -100,7 +112,7 @@ class SASTPostProcess(object):
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
"""Restore quad."""
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
xy_text = xy_text[:, ::-1] # (n, 2)
xy_text = xy_text[:, ::-1] # (n, 2)
# Sort the text boxes via the y axis
xy_text = xy_text[np.argsort(xy_text[:, 1])]
......@@ -112,7 +124,7 @@ class SASTPostProcess(object):
point_num = int(tvo_map.shape[-1] / 2)
assert point_num == 4
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
quads = xy_text_tile - tvo_map
return scores, quads, xy_text
......@@ -121,14 +133,12 @@ class SASTPostProcess(object):
"""
compute area of a quad.
"""
edge = [
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
]
edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
return np.sum(edge) / 2.
def nms(self, dets):
if self.is_python35:
import lanms
......@@ -141,7 +151,7 @@ class SASTPostProcess(object):
"""
Cluster pixels in tcl_map based on quads.
"""
instance_count = quads.shape[0] + 1 # contain background
instance_count = quads.shape[0] + 1 # contain background
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
if instance_count == 1:
return instance_count, instance_label_map
......@@ -149,18 +159,19 @@ class SASTPostProcess(object):
# predict text center
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
n = xy_text.shape[0]
xy_text = xy_text[:, ::-1] # (n, 2)
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
xy_text = xy_text[:, ::-1] # (n, 2)
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
pred_tc = xy_text - tco
# get gt text center
m = quads.shape[0]
gt_tc = np.mean(quads, axis=1) # (m, 2)
gt_tc = np.mean(quads, axis=1) # (m, 2)
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
(1, m, 1)) # (n, m, 2)
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
return instance_count, instance_label_map
......@@ -169,26 +180,47 @@ class SASTPostProcess(object):
"""
Estimate sample points number.
"""
eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
eh = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[1] - quad[2])) / 2.0
ew = (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3])) / 2.0
dense_sample_pts_num = max(2, int(ew))
dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
endpoint=True, dtype=np.float32).astype(np.int32)]
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
dense_xy_center_line = xy_text[np.linspace(
0,
xy_text.shape[0] - 1,
dense_sample_pts_num,
endpoint=True,
dtype=np.float32).astype(np.int32)]
dense_xy_center_line_diff = dense_xy_center_line[
1:] - dense_xy_center_line[:-1]
estimate_arc_len = np.sum(
np.linalg.norm(
dense_xy_center_line_diff, axis=1))
sample_pts_num = max(2, int(estimate_arc_len / eh))
return sample_pts_num
def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
def detect_sast(self,
tcl_map,
tvo_map,
tbo_map,
tco_map,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=0.3,
tcl_map_thresh=0.5,
offset_expand=1.0,
out_strid=4.0):
"""
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
"""
# restore quad
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
tvo_map)
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
dets = self.nms(dets)
if dets.shape[0] == 0:
......@@ -202,7 +234,8 @@ class SASTPostProcess(object):
# instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
instance_count, instance_label_map = self.cluster_by_quads_tco(
tcl_map, tcl_map_thresh, quads, tco_map)
# restore single poly with tcl instance.
poly_list = []
......@@ -212,10 +245,10 @@ class SASTPostProcess(object):
q_area = quad_areas[instance_idx - 1]
if q_area < 5:
continue
#
len1 = float(np.linalg.norm(quad[0] -quad[1]))
len2 = float(np.linalg.norm(quad[1] -quad[2]))
len1 = float(np.linalg.norm(quad[0] - quad[1]))
len2 = float(np.linalg.norm(quad[1] - quad[2]))
min_len = min(len1, len2)
if min_len < 3:
continue
......@@ -225,16 +258,18 @@ class SASTPostProcess(object):
continue
# filter low confidence instance
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
continue
# sort xy_text
left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
left_center_pt = np.array(
[[(quad[0, 0] + quad[-1, 0]) / 2.0,
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
right_center_pt = np.array(
[[(quad[1, 0] + quad[2, 0]) / 2.0,
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / \
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
......@@ -245,33 +280,45 @@ class SASTPostProcess(object):
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
else:
sample_pts_num = self.sample_pts_num
xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
endpoint=True, dtype=np.float32).astype(np.int32)]
xy_center_line = xy_text[np.linspace(
0,
xy_text.shape[0] - 1,
sample_pts_num,
endpoint=True,
dtype=np.float32).astype(np.int32)]
point_pair_list = []
for x, y in xy_center_line:
# get corresponding offset
offset = tbo_map[y, x, :].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
# original point
offset = offset + offset_detal
# original point
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# ndarry: (x, 2), expand poly along width
detected_poly = self.point_pair2poly(point_pair_list)
detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
detected_poly = self.expand_poly_along_width(detected_poly,
shrink_ratio_of_width)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
poly_list.append(detected_poly)
return poly_list
def __call__(self, outs_dict, shape_list):
def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score']
border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo']
......@@ -281,20 +328,28 @@ class SASTPostProcess(object):
border_list = border_list.numpy()
tvo_list = tvo_list.numpy()
tco_list = tco_list.numpy()
img_num = len(shape_list)
poly_lists = []
for ino in range(img_num):
p_score = score_list[ino].transpose((1,2,0))
p_border = border_list[ino].transpose((1,2,0))
p_tvo = tvo_list[ino].transpose((1,2,0))
p_tco = tco_list[ino].transpose((1,2,0))
p_score = score_list[ino].transpose((1, 2, 0))
p_border = border_list[ino].transpose((1, 2, 0))
p_tvo = tvo_list[ino].transpose((1, 2, 0))
p_tco = tco_list[ino].transpose((1, 2, 0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
shrink_ratio_of_width=self.shrink_ratio_of_width,
tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
poly_list = self.detect_sast(
p_score,
p_tvo,
p_border,
p_tco,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=self.shrink_ratio_of_width,
tcl_map_thresh=self.tcl_map_thresh,
offset_expand=self.expand_scale)
poly_lists.append({'points': np.array(poly_list)})
return poly_lists
from os import listdir
import os, sys
from scipy import io
import numpy as np
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm
try: # python2
range = xrange
except Exception:
# python3
range = range
"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')'
"""
# if len(sys.argv) != 4:
# print('\n usage: test.py pred_dir gt_dir savefile')
# sys.exit()
def get_socre(gt_dict, pred_dict):
# allInputs = listdir(input_dir)
allInputs = 1
def input_reading_mod(pred_dict, input):
"""This helper reads input from txt files"""
det = []
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['text']
# for i in range(len(points)):
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
def gt_reading_mod(gt_dict, gt_id):
"""This helper reads groundtruths from mat files"""
# gt_id = gt_id.split('.')[0]
gt = []
n = len(gt_dict)
for i in range(n):
points = gt_dict[i]['points'].tolist()
h = len(points)
text = gt_dict[i]['text']
xx = [
np.array(
['x:'], dtype='<U2'), 0, np.array(
['y:'], dtype='<U2'), 0, np.array(
['#'], dtype='<U1'), np.array(
['#'], dtype='<U1')
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
xx[1] = np.array([t_x], dtype='int16')
xx[3] = np.array([t_y], dtype='int16')
if text != "":
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
xx[5] = np.array(['c'], dtype='<U1')
gt.append(xx)
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
print
"liushanshan gt[1] = {}".format(gt[1])
print
"liushanshan gt[2] = {}".format(gt[2])
print
"liushanshan gt[3] = {}".format(gt[3])
print
"liushanshan gt[4] = {}".format(gt[4])
print
"liushanshan gt[5] = {}".format(gt[5])
if (gt[5] == '#') and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
# detection = detection.split(',')
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
if det_gt_iou > threshold:
detections[det_id] = []
detections[:] = [item for item in detections if item != []]
return detections
def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
# print(area_of_intersection(det_x, det_y, gt_x, gt_y))
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(gt_x, gt_y)), 2)
def tau_calculation(det_x, det_y, gt_x, gt_y):
"""
tau = inter_area / det_area
"""
# print "liushanshan det_x {}".format(det_x)
# print "liushanshan det_y {}".format(det_y)
# print "liushanshan area {}".format(area(det_x, det_y))
# print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2))
if area(det_x, det_y) == 0.0:
return 0
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(det_x, det_y)), 2)
##############################Initialization###################################
global_tp = 0
global_fp = 0
global_fn = 0
global_sigma = []
global_tau = []
tr = 0.7
tp = 0.6
fsc_k = 0.8
k = 2
global_pred_str = []
global_gt_str = []
###############################################################################
for input_id in range(allInputs):
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'):
print(input_id)
detections = input_reading_mod(pred_dict, input_id)
# print "liushanshan detections = {}".format(detections)
groundtruths = gt_reading_mod(gt_dict, input_id)
detections = detection_filtering(
detections,
groundtruths) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
if groundtruths[i][5] == '#':
dc_id.append(i)
cnt = 0
for a in dc_id:
num = a - cnt
del groundtruths[num]
cnt += 1
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
local_tau_table = np.zeros((len(groundtruths), len(detections)))
local_pred_str = {}
local_gt_str = {}
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
det_y = detection[1::2]
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
det_x, det_y, gt_x, gt_y)
local_tau_table[gt_id, det_id] = tau_calculation(
det_x, det_y, gt_x, gt_y)
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
global_sigma.append(local_sigma_table)
global_tau.append(local_tau_table)
global_pred_str.append(local_pred_str)
global_gt_str.append(local_gt_str)
print
"liushanshan global_pred_str = {}".format(global_pred_str)
print
"liushanshan global_gt_str = {}".format(global_gt_str)
global_accumulative_recall = 0
global_accumulative_precision = 0
total_num_gt = 0
total_num_det = 0
hit_str_count = 0
hit_count = 0
def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
local_sigma_table[gt_id, :] > tr)
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
0].shape[0]
gt_matching_qualified_tau_candidates = np.where(
local_tau_table[gt_id, :] > tp)
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
0].shape[0]
det_matching_qualified_sigma_candidates = np.where(
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
> tr)
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
0].shape[0]
det_matching_qualified_tau_candidates = np.where(
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
tp)
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
0].shape[0]
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
(det_matching_num_qualified_sigma_candidates == 1) and (
det_matching_num_qualified_tau_candidates == 1):
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start
print
"liushanshan one to one det_id = {}".format(matched_det_id)
print
"liushanshan one to one gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]]
print
"liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
if gt_flag[0, gt_id] > 0:
continue
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
if num_non_zero_in_sigma >= k:
####search for all detections that overlaps with this groundtruth
qualified_tau_candidates = np.where((local_tau_table[
gt_id, :] >= tp) & (det_flag[0, :] == 0))
num_qualified_tau_candidates = qualified_tau_candidates[
0].shape[0]
if num_qualified_tau_candidates == 1:
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
and
(local_sigma_table[gt_id, qualified_tau_candidates] >=
tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
>= tr):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
local_accumulative_recall = local_accumulative_recall + fsc_k
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
if det_flag[0, det_id] > 0:
continue
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
if num_non_zero_in_tau >= k:
####search for all detections that overlaps with this groundtruth
qualified_sigma_candidates = np.where((
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
num_qualified_sigma_candidates = qualified_sigma_candidates[
0].shape[0]
if num_qualified_sigma_candidates == 1:
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
tp) and
(local_sigma_table[qualified_sigma_candidates, det_id]
>= tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
# recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[
idx]
if not global_gt_str[idy].has_key(ele_gt_id):
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
elif (np.sum(local_tau_table[qualified_sigma_candidates,
det_id]) >= tp):
det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1
# recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if not global_gt_str[idy].has_key(ele_gt_id):
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
else:
print
'no match'
# recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
global_accumulative_precision = global_accumulative_precision + fsc_k
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
local_accumulative_precision = local_accumulative_precision + fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
single_data = {}
for idx in range(len(global_sigma)):
# print(allInputs[idx])
local_sigma_table = global_sigma[idx]
local_tau_table = global_tau[idx]
num_gt = local_sigma_table.shape[0]
num_det = local_sigma_table.shape[1]
total_num_gt = total_num_gt + num_gt
total_num_det = total_num_det + num_det
local_accumulative_recall = 0
local_accumulative_precision = 0
gt_flag = np.zeros((1, num_gt))
det_flag = np.zeros((1, num_det))
#######first check for one-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
# fid = open(fid_path, 'a+')
try:
local_precision = local_accumulative_precision / num_det
except ZeroDivisionError:
local_precision = 0
try:
local_recall = local_accumulative_recall / num_gt
except ZeroDivisionError:
local_recall = 0
try:
local_f_score = 2 * local_precision * local_recall / (
local_precision + local_recall)
except ZeroDivisionError:
local_f_score = 0
# temp = ('%s: Recall=%.4f, Precision=%.4f, f_score=%.4f\n' % (
# allInputs[idx], local_recall, local_precision, local_f_score))
single_data['sigma'] = global_sigma
single_data['global_tau'] = global_tau
single_data['global_pred_str'] = global_pred_str
single_data['global_gt_str'] = global_gt_str
single_data["recall"] = local_recall
single_data['precision'] = local_precision
single_data['f_score'] = local_f_score
return single_data
def combine_results(all_data):
tr = 0.7
tp = 0.6
fsc_k = 0.8
k = 2
global_sigma = []
global_tau = []
global_pred_str = []
global_gt_str = []
for data in all_data:
global_sigma.append(data['sigma'][0])
global_tau.append(data['global_tau'][0])
global_pred_str.append(data['global_pred_str'][0])
global_gt_str.append(data['global_gt_str'][0])
global_accumulative_recall = 0
global_accumulative_precision = 0
total_num_gt = 0
total_num_det = 0
hit_str_count = 0
hit_count = 0
def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
local_sigma_table[gt_id, :] > tr)
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
0].shape[0]
gt_matching_qualified_tau_candidates = np.where(
local_tau_table[gt_id, :] > tp)
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
0].shape[0]
det_matching_qualified_sigma_candidates = np.where(
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
> tr)
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
0].shape[0]
det_matching_qualified_tau_candidates = np.where(
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
tp)
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
0].shape[0]
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
(det_matching_num_qualified_sigma_candidates == 1) and (
det_matching_num_qualified_tau_candidates == 1):
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start
print
"liushanshan one to one det_id = {}".format(matched_det_id)
print
"liushanshan one to one gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]]
print
"liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
if gt_flag[0, gt_id] > 0:
continue
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
if num_non_zero_in_sigma >= k:
####search for all detections that overlaps with this groundtruth
qualified_tau_candidates = np.where((local_tau_table[
gt_id, :] >= tp) & (det_flag[0, :] == 0))
num_qualified_tau_candidates = qualified_tau_candidates[
0].shape[0]
if num_qualified_tau_candidates == 1:
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
and
(local_sigma_table[gt_id, qualified_tau_candidates] >=
tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
>= tr):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
local_accumulative_recall = local_accumulative_recall + fsc_k
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
if det_flag[0, det_id] > 0:
continue
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
if num_non_zero_in_tau >= k:
####search for all detections that overlaps with this groundtruth
qualified_sigma_candidates = np.where((
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
num_qualified_sigma_candidates = qualified_sigma_candidates[
0].shape[0]
if num_qualified_sigma_candidates == 1:
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
tp) and
(local_sigma_table[qualified_sigma_candidates, det_id]
>= tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
# recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[
idx]
if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
elif (np.sum(local_tau_table[qualified_sigma_candidates,
det_id]) >= tp):
det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1
# recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if not global_gt_str[idy].has_key(ele_gt_id):
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
else:
print
'no match'
# recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
global_accumulative_precision = global_accumulative_precision + fsc_k
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
local_accumulative_precision = local_accumulative_precision + fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
for idx in range(len(global_sigma)):
local_sigma_table = np.array(global_sigma[idx])
local_tau_table = global_tau[idx]
num_gt = local_sigma_table.shape[0]
num_det = local_sigma_table.shape[1]
total_num_gt = total_num_gt + num_gt
total_num_det = total_num_det + num_det
local_accumulative_recall = 0
local_accumulative_precision = 0
gt_flag = np.zeros((1, num_gt))
det_flag = np.zeros((1, num_det))
#######first check for one-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
try:
recall = global_accumulative_recall / total_num_gt
except ZeroDivisionError:
recall = 0
try:
precision = global_accumulative_precision / total_num_det
except ZeroDivisionError:
precision = 0
try:
f_score = 2 * precision * recall / (precision + recall)
except ZeroDivisionError:
f_score = 0
try:
seqerr = 1 - float(hit_str_count) / global_accumulative_recall
except ZeroDivisionError:
seqerr = 1
try:
recall_e2e = float(hit_str_count) / total_num_gt
except ZeroDivisionError:
recall_e2e = 0
try:
precision_e2e = float(hit_str_count) / total_num_det
except ZeroDivisionError:
precision_e2e = 0
try:
f_score_e2e = 2 * precision_e2e * recall_e2e / (
precision_e2e + recall_e2e)
except ZeroDivisionError:
f_score_e2e = 0
final = {
'total_num_gt': total_num_gt,
'total_num_det': total_num_det,
'global_accumulative_recall': global_accumulative_recall,
'hit_str_count': hit_str_count,
'recall': recall,
'precision': precision,
'f_score': f_score,
'seqerr': seqerr,
'recall_e2e': recall_e2e,
'precision_e2e': precision_e2e,
'f_score_e2e': f_score_e2e
}
return final
# a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, 1659, 620, 1654, 681, 1631, 680, 1618,
# 681, 1606, 681, 1594, 681, 1584, 682, 1573, 685, 1542, 694]
# gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
# pred_dict = [{'points': np.array(a), 'text': 'ccc'},
# {'points': np.array(a), 'text': 'ccf'}]
# result = []
# for i in range(2):
# result.append(get_socre(gt_dict, pred_dict))
# print(111)
# a = combine_results(result)
# print(a)
import numpy as np
from shapely.geometry import Polygon
#import Polygon
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices
:param gt_x: [1, N] Xs of groundtruth's vertices
:param gt_y: [1, N] Ys of groundtruth's vertices
##############
All the calculation of 'AREA' in this script is handled by:
1) First generating a binary mask with the polygon area filled up with 1's
2) Summing up all the 1's
"""
def area(x, y):
polygon = Polygon(np.stack([x, y], axis=1))
return float(polygon.area)
def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
"""
This helper determine if both polygons are intersecting with each others with an approximation method.
Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax]
"""
det_ymax = np.max(det_y)
det_xmax = np.max(det_x)
det_ymin = np.min(det_y)
det_xmin = np.min(det_x)
gt_ymax = np.max(gt_y)
gt_xmax = np.max(gt_x)
gt_ymin = np.min(gt_y)
gt_xmin = np.min(gt_x)
all_min_ymax = np.minimum(det_ymax, gt_ymax)
all_max_ymin = np.maximum(det_ymin, gt_ymin)
intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
all_min_xmax = np.minimum(det_xmax, gt_xmax)
all_max_xmin = np.maximum(det_xmin, gt_xmin)
intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
return intersect_heights * intersect_widths
def area_of_intersection(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.intersection(p2).area)
def area_of_union(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.union(p2).area)
def iou(det_x, det_y, gt_x, gt_y):
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
def iod(det_x, det_y, gt_x, gt_y):
"""
This helper determine the fraction of intersection area over detection area
"""
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area(det_x, det_y) + 1.0)
from os import listdir
import os, sys
from scipy import io
import numpy as np
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm
try: # python2
range = xrange
except Exception:
# python3
range = range
"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')'
"""
# if len(sys.argv) != 4:
# print('\n usage: test.py pred_dir gt_dir savefile')
# sys.exit()
global_tp = 0
global_fp = 0
global_fn = 0
tr = 0.7
tp = 0.6
fsc_k = 0.8
k = 2
def get_socre(gt_dict, pred_dict):
# allInputs = listdir(input_dir)
allInputs = 1
global_pred_str = []
global_gt_str = []
global_sigma = []
global_tau = []
def input_reading_mod(pred_dict, input):
"""This helper reads input from txt files"""
det = []
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['text']
# for i in range(len(points)):
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
def gt_reading_mod(gt_dict, gt_id):
"""This helper reads groundtruths from mat files"""
# gt_id = gt_id.split('.')[0]
gt = []
n = len(gt_dict)
for i in range(n):
points = gt_dict[i]['points'].tolist()
h = len(points)
text = gt_dict[i]['text']
xx = [
np.array(
['x:'], dtype='<U2'), 0, np.array(
['y:'], dtype='<U2'), 0, np.array(
['#'], dtype='<U1'), np.array(
['#'], dtype='<U1')
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
xx[1] = np.array([t_x], dtype='int16')
xx[3] = np.array([t_y], dtype='int16')
if text != "":
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
xx[5] = np.array(['c'], dtype='<U1')
gt.append(xx)
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
print
"liushanshan gt[1] = {}".format(gt[1])
print
"liushanshan gt[2] = {}".format(gt[2])
print
"liushanshan gt[3] = {}".format(gt[3])
print
"liushanshan gt[4] = {}".format(gt[4])
print
"liushanshan gt[5] = {}".format(gt[5])
if (gt[5] == '#') and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
# detection = detection.split(',')
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
if det_gt_iou > threshold:
detections[det_id] = []
detections[:] = [item for item in detections if item != []]
return detections
def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
# print(area_of_intersection(det_x, det_y, gt_x, gt_y))
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(gt_x, gt_y)), 2)
def tau_calculation(det_x, det_y, gt_x, gt_y):
"""
tau = inter_area / det_area
"""
# print "liushanshan det_x {}".format(det_x)
# print "liushanshan det_y {}".format(det_y)
# print "liushanshan area {}".format(area(det_x, det_y))
# print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2))
if area(det_x, det_y) == 0.0:
return 0
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(det_x, det_y)), 2)
##############################Initialization###################################
###############################################################################
single_data = {}
for input_id in range(allInputs):
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'):
print(input_id)
detections = input_reading_mod(pred_dict, input_id)
# print "liushanshan detections = {}".format(detections)
groundtruths = gt_reading_mod(gt_dict, input_id)
detections = detection_filtering(
detections,
groundtruths) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
if groundtruths[i][5] == '#':
dc_id.append(i)
cnt = 0
for a in dc_id:
num = a - cnt
del groundtruths[num]
cnt += 1
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
local_tau_table = np.zeros((len(groundtruths), len(detections)))
local_pred_str = {}
local_gt_str = {}
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
det_y = detection[1::2]
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
det_x, det_y, gt_x, gt_y)
local_tau_table[gt_id, det_id] = tau_calculation(
det_x, det_y, gt_x, gt_y)
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
global_sigma.append(local_sigma_table)
global_tau.append(local_tau_table)
global_pred_str.append(local_pred_str)
global_gt_str.append(local_gt_str)
print
"liushanshan global_pred_str = {}".format(global_pred_str)
print
"liushanshan global_gt_str = {}".format(global_gt_str)
single_data['sigma'] = global_sigma
single_data['global_tau'] = global_tau
single_data['global_pred_str'] = global_pred_str
single_data['global_gt_str'] = global_gt_str
return single_data
def combine_results(all_data):
global_sigma, global_tau, global_pred_str, global_gt_str = [], [], [], []
for data in all_data:
global_sigma.append(data['sigma'])
global_tau.append(data['global_tau'])
global_pred_str.append(data['global_pred_str'])
global_gt_str.append(data['global_gt_str'])
global_accumulative_recall = 0
global_accumulative_precision = 0
total_num_gt = 0
total_num_det = 0
hit_str_count = 0
hit_count = 0
def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
local_sigma_table[gt_id, :] > tr)
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
0].shape[0]
gt_matching_qualified_tau_candidates = np.where(
local_tau_table[gt_id, :] > tp)
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
0].shape[0]
det_matching_qualified_sigma_candidates = np.where(
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
> tr)
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
0].shape[0]
det_matching_qualified_tau_candidates = np.where(
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
tp)
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
0].shape[0]
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
(det_matching_num_qualified_sigma_candidates == 1) and (
det_matching_num_qualified_tau_candidates == 1):
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start
print
"liushanshan one to one det_id = {}".format(matched_det_id)
print
"liushanshan one to one gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]]
print
"liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
if gt_flag[0, gt_id] > 0:
continue
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
if num_non_zero_in_sigma >= k:
####search for all detections that overlaps with this groundtruth
qualified_tau_candidates = np.where((local_tau_table[
gt_id, :] >= tp) & (det_flag[0, :] == 0))
num_qualified_tau_candidates = qualified_tau_candidates[
0].shape[0]
if num_qualified_tau_candidates == 1:
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
and
(local_sigma_table[gt_id, qualified_tau_candidates] >=
tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
>= tr):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
local_accumulative_recall = local_accumulative_recall + fsc_k
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
if det_flag[0, det_id] > 0:
continue
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
if num_non_zero_in_tau >= k:
####search for all detections that overlaps with this groundtruth
qualified_sigma_candidates = np.where((
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
num_qualified_sigma_candidates = qualified_sigma_candidates[
0].shape[0]
if num_qualified_sigma_candidates == 1:
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
tp) and
(local_sigma_table[qualified_sigma_candidates, det_id]
>= tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
# recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[
idx]
if not global_gt_str[idy].has_key(ele_gt_id):
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
elif (np.sum(local_tau_table[qualified_sigma_candidates,
det_id]) >= tp):
det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1
# recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if not global_gt_str[idy].has_key(ele_gt_id):
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
else:
print
'no match'
# recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
global_accumulative_precision = global_accumulative_precision + fsc_k
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
local_accumulative_precision = local_accumulative_precision + fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
for idx in range(len(global_sigma)):
# print(allInputs[idx])
local_sigma_table = np.array(global_sigma[idx])
local_tau_table = global_tau[idx]
num_gt = local_sigma_table.shape[0]
num_det = local_sigma_table.shape[1]
total_num_gt = total_num_gt + num_gt
total_num_det = total_num_det + num_det
local_accumulative_recall = 0
local_accumulative_precision = 0
gt_flag = np.zeros((1, num_gt))
det_flag = np.zeros((1, num_det))
#######first check for one-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
try:
recall = global_accumulative_recall / total_num_gt
except ZeroDivisionError:
recall = 0
try:
precision = global_accumulative_precision / total_num_det
except ZeroDivisionError:
precision = 0
try:
f_score = 2 * precision * recall / (precision + recall)
except ZeroDivisionError:
f_score = 0
try:
seqerr = 1 - float(hit_str_count) / global_accumulative_recall
except ZeroDivisionError:
seqerr = 1
try:
recall_e2e = float(hit_str_count) / total_num_gt
except ZeroDivisionError:
recall_e2e = 0
try:
precision_e2e = float(hit_str_count) / total_num_det
except ZeroDivisionError:
precision_e2e = 0
try:
f_score_e2e = 2 * precision_e2e * recall_e2e / (
precision_e2e + recall_e2e)
except ZeroDivisionError:
f_score_e2e = 0
final = {
'total_num_gt': total_num_gt,
'total_num_det': total_num_det,
'global_accumulative_recall': global_accumulative_recall,
'hit_str_count': hit_str_count,
'recall': recall,
'precision': precision,
'f_score': f_score,
'seqerr': seqerr,
'recall_e2e': recall_e2e,
'precision_e2e': precision_e2e,
'f_score_e2e': f_score_e2e
}
return final
# def combine_results(all_data):
# tr = 0.7
# tp = 0.6
# fsc_k = 0.8
# k = 2
# global_sigma = []
# global_tau = []
# global_pred_str = []
# global_gt_str = []
# for data in all_data:
# global_sigma.append(data['sigma'])
# global_tau.append(data['global_tau'])
# global_pred_str.append(data['global_pred_str'])
# global_gt_str.append(data['global_gt_str'])
#
# global_accumulative_recall = 0
# global_accumulative_precision = 0
# total_num_gt = 0
# total_num_det = 0
# hit_str_count = 0
# hit_count = 0
#
# def one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall,
# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idy):
# hit_str_num = 0
# for gt_id in range(num_gt):
# gt_matching_qualified_sigma_candidates = np.where(local_sigma_table[gt_id, :] > tr)
# gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[0].shape[0]
# gt_matching_qualified_tau_candidates = np.where(local_tau_table[gt_id, :] > tp)
# gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[0].shape[0]
#
# det_matching_qualified_sigma_candidates = np.where(
# local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] > tr)
# det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[0].shape[0]
# det_matching_qualified_tau_candidates = np.where(
# local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > tp)
# det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[0].shape[0]
#
# if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
# (det_matching_num_qualified_sigma_candidates == 1) and (
# det_matching_num_qualified_tau_candidates == 1):
# global_accumulative_recall = global_accumulative_recall + 1.0
# global_accumulative_precision = global_accumulative_precision + 1.0
# local_accumulative_recall = local_accumulative_recall + 1.0
# local_accumulative_precision = local_accumulative_precision + 1.0
#
# gt_flag[0, gt_id] = 1
# matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# # recg start
# print
# "liushanshan one to one det_id = {}".format(matched_det_id)
# print
# "liushanshan one to one gt_id = {}".format(gt_id)
# gt_str_cur = global_gt_str[idy][gt_id]
# pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[0]]
# print
# "liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# # recg end
# det_flag[0, matched_det_id] = 1
# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
#
# def one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall,
# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idy):
# hit_str_num = 0
# for gt_id in range(num_gt):
# # skip the following if the groundtruth was matched
# if gt_flag[0, gt_id] > 0:
# continue
#
# non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
# num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
#
# if num_non_zero_in_sigma >= k:
# ####search for all detections that overlaps with this groundtruth
# qualified_tau_candidates = np.where((local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0))
# num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0]
#
# if num_qualified_tau_candidates == 1:
# if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) and (
# local_sigma_table[gt_id, qualified_tau_candidates] >= tr)):
# # became an one-to-one case
# global_accumulative_recall = global_accumulative_recall + 1.0
# global_accumulative_precision = global_accumulative_precision + 1.0
# local_accumulative_recall = local_accumulative_recall + 1.0
# local_accumulative_precision = local_accumulative_precision + 1.0
#
# gt_flag[0, gt_id] = 1
# det_flag[0, qualified_tau_candidates] = 1
# # recg start
# print
# "liushanshan one to many det_id = {}".format(qualified_tau_candidates)
# print
# "liushanshan one to many gt_id = {}".format(gt_id)
# gt_str_cur = global_gt_str[idy][gt_id]
# pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
# print
# "liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan one to many pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# # recg end
# elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr):
# gt_flag[0, gt_id] = 1
# det_flag[0, qualified_tau_candidates] = 1
# # recg start
# print
# "liushanshan one to many det_id = {}".format(qualified_tau_candidates)
# print
# "liushanshan one to many gt_id = {}".format(gt_id)
# gt_str_cur = global_gt_str[idy][gt_id]
# pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
# print
# "liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan one to many pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# # recg end
#
# global_accumulative_recall = global_accumulative_recall + fsc_k
# global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
#
# local_accumulative_recall = local_accumulative_recall + fsc_k
# local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
#
# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
#
# def many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall,
# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idy):
# hit_str_num = 0
# for det_id in range(num_det):
# # skip the following if the detection was matched
# if det_flag[0, det_id] > 0:
# continue
#
# non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
# num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
#
# if num_non_zero_in_tau >= k:
# ####search for all detections that overlaps with this groundtruth
# qualified_sigma_candidates = np.where((local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
# num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0]
#
# if num_qualified_sigma_candidates == 1:
# if ((local_tau_table[qualified_sigma_candidates, det_id] >= tp) and (
# local_sigma_table[qualified_sigma_candidates, det_id] >= tr)):
# # became an one-to-one case
# global_accumulative_recall = global_accumulative_recall + 1.0
# global_accumulative_precision = global_accumulative_precision + 1.0
# local_accumulative_recall = local_accumulative_recall + 1.0
# local_accumulative_precision = local_accumulative_precision + 1.0
#
# gt_flag[0, qualified_sigma_candidates] = 1
# det_flag[0, det_id] = 1
# # recg start
# print
# "liushanshan many to one det_id = {}".format(det_id)
# print
# "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates)
# pred_str_cur = global_pred_str[idy][det_id]
# gt_len = len(qualified_sigma_candidates[0])
# for idx in range(gt_len):
# ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
# if ele_gt_id not in global_gt_str[idy]:
# continue
# gt_str_cur = global_gt_str[idy][ele_gt_id]
# print
# "liushanshan many to one gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan many to one pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# break
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# break
# # recg end
# elif (np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp):
# det_flag[0, det_id] = 1
# gt_flag[0, qualified_sigma_candidates] = 1
# # recg start
# print
# "liushanshan many to one det_id = {}".format(det_id)
# print
# "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates)
# pred_str_cur = global_pred_str[idy][det_id]
# gt_len = len(qualified_sigma_candidates[0])
# for idx in range(gt_len):
# ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
# if not global_gt_str[idy].has_key(ele_gt_id):
# continue
# gt_str_cur = global_gt_str[idy][ele_gt_id]
# print
# "liushanshan many to one gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan many to one pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# break
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# break
# else:
# print
# 'no match'
# # recg end
#
# global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
# global_accumulative_precision = global_accumulative_precision + fsc_k
#
# local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
# local_accumulative_precision = local_accumulative_precision + fsc_k
# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
#
# for idx in range(len(global_sigma)):
# local_sigma_table = np.array(global_sigma[idx])
# local_tau_table = np.array(global_tau[idx])
#
# num_gt = local_sigma_table.shape[0]
# num_det = local_sigma_table.shape[1]
#
# total_num_gt = total_num_gt + num_gt
# total_num_det = total_num_det + num_det
#
# local_accumulative_recall = 0
# local_accumulative_precision = 0
# gt_flag = np.zeros((1, num_gt))
# det_flag = np.zeros((1, num_det))
#
# #######first check for one-to-one case##########
# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
# gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
# local_accumulative_recall, local_accumulative_precision,
# global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idx)
#
# hit_str_count += hit_str_num
# #######then check for one-to-many case##########
# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
# gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
# local_accumulative_recall, local_accumulative_precision,
# global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idx)
# hit_str_count += hit_str_num
# #######then check for many-to-one case##########
# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
# gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
# local_accumulative_recall, local_accumulative_precision,
# global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idx)
# try:
# recall = global_accumulative_recall / total_num_gt
# except ZeroDivisionError:
# recall = 0
#
# try:
# precision = global_accumulative_precision / total_num_det
# except ZeroDivisionError:
# precision = 0
#
# try:
# f_score = 2 * precision * recall / (precision + recall)
# except ZeroDivisionError:
# f_score = 0
#
# try:
# seqerr = 1 - float(hit_str_count) / global_accumulative_recall
# except ZeroDivisionError:
# seqerr = 1
#
# try:
# recall_e2e = float(hit_str_count) / total_num_gt
# except ZeroDivisionError:
# recall_e2e = 0
#
# try:
# precision_e2e = float(hit_str_count) / total_num_det
# except ZeroDivisionError:
# precision_e2e = 0
#
# try:
# f_score_e2e = 2 * precision_e2e * recall_e2e / (precision_e2e + recall_e2e)
# except ZeroDivisionError:
# f_score_e2e = 0
#
# final = {
# 'total_num_gt': total_num_gt,
# 'total_num_det': total_num_det,
# 'global_accumulative_recall': global_accumulative_recall,
# 'hit_str_count': hit_str_count,
# 'recall': recall,
# 'precision': precision,
# 'f_score': f_score,
# 'seqerr': seqerr,
# 'recall_e2e': recall_e2e,
# 'precision_e2e': precision_e2e,
# 'f_score_e2e': f_score_e2e
# }
# return final
a = [
1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620,
1659, 620, 1654, 681, 1631, 680, 1618, 681, 1606, 681, 1594, 681, 1584, 682,
1573, 685, 1542, 694
]
gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
pred_dict = [{
'points': np.array(a),
'text': 'ccc'
}, {
'points': np.array(a),
'text': 'ccf'
}]
result = []
result.append(get_socre(gt_dict, gt_dict))
a = combine_results(result)
print(a)
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import cv2
import time
import math
import numpy as np
from itertools import groupby
from ppocr.utils.e2e_utils.ski_thin import thin
def softmax(logits):
"""
logits: N x d
"""
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def get_keep_pos_idxs(labels, remove_blank=None):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list = []
keep_pos_idx_list = []
keep_char_idx_list = []
for k, v_ in groupby(labels):
current_len = len(list(v_))
if k != remove_blank:
current_idx = int(sum(duplicate_len_list) + current_len // 2)
keep_pos_idx_list.append(current_idx)
keep_char_idx_list.append(k)
duplicate_len_list.append(current_len)
return keep_char_idx_list, keep_pos_idx_list
def remove_blank(labels, blank=0):
new_labels = [x for x in labels if x != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
"""
CTC greedy (best path) decoder.
"""
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
raw_str, remove_blank=remove_blank_in_pos)
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
def instance_ctc_greedy_decoder(gather_info,
logits_map,
keep_blank_in_idxs=True):
"""
gather_info: [[x, y], [x, y] ...]
logits_map: H x W X (n_chars + 1)
"""
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)] # n x 96
probs_seq = softmax(logits_seq)
dst_str, keep_idx_list = ctc_greedy_decoder(
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list, logits_map,
keep_blank_in_idxs=True):
"""
CTC decoder using multiple processes.
"""
decoder_results = []
for gather_info in gather_info_list:
res = instance_ctc_greedy_decoder(
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
decoder_results.append(res)
return decoder_results
def sort_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point, np.array(sorted_direction)
def add_id(pos_list, image_id=0):
"""
Add id for gather feature, for inference.
"""
new_list = []
for item in pos_list:
new_list.append((image_id, item[0], item[1]))
return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_dirction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
binary_tcl_map: h x w
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_dirction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
else:
break
for i in range(max_append_num):
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
else:
break
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def generate_pivot_list_curved(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_expand=True,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
if is_expand:
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
else:
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter backgroud points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list_horizontal(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map_bi = (p_score > score_thresh) * 1.0
instance_count, instance_label_map = cv2.connectedComponents(
p_tcl_map_bi.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 5:
continue
# add rule here
main_direction = extract_main_direction(pos_list,
f_direction) # y x
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
is_h_angle = abs(np.sum(
main_direction * reference_directin)) < math.cos(math.pi / 180 *
70)
point_yxs = np.array(pos_list)
max_y, max_x = np.max(point_yxs, axis=0)
min_y, min_x = np.min(point_yxs, axis=0)
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
pos_list_final = []
if is_h_len:
xs = np.unique(xs)
for x in xs:
ys = instance_label_map[:, x].copy().reshape((-1, ))
y = int(np.where(ys == instance_id)[0].mean())
pos_list_final.append((y, x))
else:
ys = np.unique(ys)
for y in ys:
xs = instance_label_map[y, :].copy().reshape((-1, ))
x = int(np.where(xs == instance_id)[0].mean())
pos_list_final.append((y, x))
pos_list_sorted, _ = sort_with_direction(pos_list_final,
f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter backgroud points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
Warp all the function together.
"""
if is_curved:
return generate_pivot_list_curved(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_expand=True,
is_backbone=is_backbone,
image_id=image_id)
else:
return generate_pivot_list_horizontal(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_backbone=is_backbone,
image_id=image_id)
# for refine module
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list = np.array(pos_list)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
average_direction = average_direction / (
np.linalg.norm(average_direction) + 1e-6)
return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full = np.array(pos_list).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list_full, point_direction):
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point
def generate_pivot_list_tt_inference(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
# pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
all_pos_yxs.append(pos_list_sorted_with_id)
return all_pos_yxs
if __name__ == '__main__':
np.random.seed(0)
import time
logits_map = np.random.random([10, 20, 33])
# a list of [x, y]
instance_gather_info_1 = [(2, 3), (2, 4), (3, 5)]
instance_gather_info_2 = [(15, 6), (15, 7), (18, 8)]
instance_gather_info_3 = [(8, 8), (8, 8), (8, 8)]
gather_info_list = [
instance_gather_info_1, instance_gather_info_2, instance_gather_info_3
]
time0 = time.time()
res = ctc_decoder_for_image(
gather_info_list, logits_map, keep_blank_in_idxs=True)
print(res)
print('cost {}'.format(time.time() - time0))
print('--' * 20)
"""
Algorithms for computing the skeleton of a binary image
"""
import numpy as np
from scipy import ndimage as ndi
G123_LUT = np.array(
[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0,
0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0
],
dtype=np.bool)
G123P_LUT = np.array(
[
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1,
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
dtype=np.bool)
def thin(image, max_iter=None):
"""
Perform morphological thinning of a binary image.
Parameters
----------
image : binary (M, N) ndarray
The image to be thinned.
max_iter : int, number of iterations, optional
Regardless of the value of this parameter, the thinned image
is returned immediately if an iteration produces no change.
If this parameter is specified it thus sets an upper bound on
the number of iterations performed.
Returns
-------
out : ndarray of bool
Thinned image.
See also
--------
skeletonize, medial_axis
Notes
-----
This algorithm [1]_ works by making multiple passes over the image,
removing pixels matching a set of criteria designed to thin
connected regions while preserving eight-connected components and
2 x 2 squares [2]_. In each of the two sub-iterations the algorithm
correlates the intermediate skeleton image with a neighborhood mask,
then looks up each neighborhood in a lookup table indicating whether
the central pixel should be deleted in that sub-iteration.
References
----------
.. [1] Z. Guo and R. W. Hall, "Parallel thinning with
two-subiteration algorithms," Comm. ACM, vol. 32, no. 3,
pp. 359-373, 1989. :DOI:`10.1145/62065.62074`
.. [2] Lam, L., Seong-Whan Lee, and Ching Y. Suen, "Thinning
Methodologies-A Comprehensive Survey," IEEE Transactions on
Pattern Analysis and Machine Intelligence, Vol 14, No. 9,
p. 879, 1992. :DOI:`10.1109/34.161346`
Examples
--------
>>> square = np.zeros((7, 7), dtype=np.uint8)
>>> square[1:-1, 2:-2] = 1
>>> square[0, 1] = 1
>>> square
array([[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
>>> skel = thin(square)
>>> skel.astype(np.uint8)
array([[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
"""
# convert image to uint8 with values in {0, 1}
skel = np.asanyarray(image, dtype=bool).astype(np.uint8)
# neighborhood mask
mask = np.array([[8, 4, 2], [16, 0, 1], [32, 64, 128]], dtype=np.uint8)
# iterate until convergence, up to the iteration limit
max_iter = max_iter or np.inf
n_iter = 0
n_pts_old, n_pts_new = np.inf, np.sum(skel)
while n_pts_old != n_pts_new and n_iter < max_iter:
n_pts_old = n_pts_new
# perform the two "subiterations" described in the paper
for lut in [G123_LUT, G123P_LUT]:
# correlate image with neighborhood mask
N = ndi.correlate(skel, mask, mode='constant')
# take deletion decision from this subiteration's LUT
D = np.take(lut, N)
# perform deletion
skel[D] = 0
n_pts_new = np.sum(skel) # count points after thinning
n_iter += 1
return skel.astype(np.bool)
import os
import numpy as np
import cv2
import time
def visualize_e2e_result(im_fn, poly_list, seq_strs, src_im):
"""
"""
result_path = './out'
im_basename = os.path.basename(im_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
vis_det_img = src_im.copy()
valid_set = 'partvgg'
gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train"
text_path = os.path.join(gt_dir, im_prefix + '.txt')
fid = open(text_path, 'r')
lines = [line.strip() for line in fid.readlines()]
for line in lines:
if valid_set == 'partvgg':
tokens = line.strip().split('\t')[0].split(',')
# tokens = line.strip().split(',')
coords = tokens[:]
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, 4, 2)
elif valid_set == 'totaltext':
tokens = line.strip().split('\t')[0].split(',')
coords = tokens[:]
coords_len = len(coords) / 2
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, coords_len, 2)
cv2.polylines(
vis_det_img,
np.array(gt_poly).astype(np.int32),
isClosed=True,
color=(255, 0, 0),
thickness=2)
for detected_poly, recognized_str in zip(poly_list, seq_strs):
cv2.polylines(
vis_det_img,
np.array(detected_poly[np.newaxis, ...]).astype(np.int32),
isClosed=True,
color=(0, 0, 255),
thickness=2)
cv2.putText(
vis_det_img,
recognized_str,
org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1)
if not os.path.exists(result_path):
os.makedirs(result_path)
cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix),
vis_det_img)
def visualization_output(src_image,
f_tcl,
f_chars,
output_dir,
image_prefix=None):
"""
"""
# restore BGR image, CHW -> HWC
im_mean = [0.485, 0.456, 0.406]
im_std = [0.229, 0.224, 0.225]
im_mean = np.array(im_mean).reshape((3, 1, 1))
im_std = np.array(im_std).reshape((3, 1, 1))
src_image *= im_std
src_image += im_mean
src_image = src_image.transpose([1, 2, 0])
src_image = src_image[:, :, ::-1] * 255 # BGR -> RGB
H, W, _ = src_image.shape
file_prefix = image_prefix if image_prefix is not None else str(
int(time.time() * 1000))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# visualization f_tcl
tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg')
vis_tcl_img = src_image.copy()
f_tcl_resized = cv2.resize(f_tcl, dsize=(W, H))
vis_tcl_img[:, :, 1] = f_tcl_resized * 255
cv2.imwrite(tcl_file_name, vis_tcl_img)
# visualization char maps
vis_char_img = src_image.copy()
# CHW -> HWC
char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg')
f_chars = np.argmax(f_chars, axis=2)[:, :, np.newaxis].astype('float32')
f_chars[f_chars < 95] = 1.0
f_chars[f_chars == 95] = 0.0
f_chars_resized = cv2.resize(f_chars, dsize=(W, H))
vis_char_img[:, :, 1] = f_chars_resized * 255
cv2.imwrite(char_file_name, vis_char_img)
def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir,
result_path):
"""
"""
im_basename = os.path.basename(im_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
vis_det_img = src_im.copy()
# draw gt bbox on the image.
text_path = os.path.join(gt_dir, im_prefix + '.txt')
fid = open(text_path, 'r')
lines = [line.strip() for line in fid.readlines()]
for line in lines:
tokens = line.strip().split('\t')
coords = tokens[0].split(',')
coords_len = len(coords)
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, coords_len / 2, 2)
cv2.polylines(
vis_det_img,
np.array(gt_poly).astype(np.int32),
isClosed=True,
color=(255, 255, 255),
thickness=1)
for point, point_pair in zip(point_list, point_pair_list):
cv2.line(
vis_det_img,
tuple(point_pair[0]),
tuple(point_pair[1]), (0, 255, 255),
thickness=1)
cv2.circle(vis_det_img, tuple(point), 2, (0, 0, 255))
cv2.circle(vis_det_img, tuple(point_pair[0]), 2, (255, 0, 0))
cv2.circle(vis_det_img, tuple(point_pair[1]), 2, (0, 255, 0))
if not os.path.exists(result_path):
os.makedirs(result_path)
cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix),
vis_det_img)
def resize_image(im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image_min(im, max_side_len=512):
"""
"""
print('--> Using resize_image_min')
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h < resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image_for_totaltext(im, max_side_len=512):
"""
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
# Fix the longer side
# if resize_h > resize_w:
# ratio = float(max_side_len) / resize_h
# else:
# ratio = float(max_side_len) / resize_w
###
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
pair_length_list = []
for point_pair in point_pair_list:
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
# constract poly
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2), pair_info
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def norm2(x, axis=None):
if axis:
return np.sqrt(np.sum(x**2, axis=axis))
return np.sqrt(np.sum(x**2))
def cos(p1, p2):
return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
def generate_direction_info(image_fn,
H,
W,
ratio_h,
ratio_w,
max_length=640,
out_scale=4,
gt_dir=None):
"""
"""
im_basename = os.path.basename(image_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
instance_direction_map = np.zeros(shape=[H // out_scale, W // out_scale, 3])
if gt_dir is None:
gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly'
# get gt label map
text_path = os.path.join(gt_dir, im_prefix + '.txt')
fid = open(text_path, 'r')
lines = [line.strip() for line in fid.readlines()]
for label_idx, line in enumerate(lines, start=1):
coords, txt = line.strip().split('\t')
if txt == '###':
continue
tokens = coords.strip().split(',')
coords = list(map(float, tokens))
poly = np.array(coords).reshape(4, 2) * np.array(
[ratio_w, ratio_h]).reshape(1, 2) / out_scale
mid_idx = poly.shape[0] // 2
direct_vector = (
(poly[mid_idx] + poly[mid_idx - 1]) - (poly[0] + poly[-1])) / 2.0
direct_vector /= len(txt)
# l2_distance = norm2(direct_vector)
# avg_char_distance = l2_distance / len(txt)
avg_char_distance = 1.0
direct_label = (direct_vector[0], direct_vector[1], avg_char_distance)
cv2.fillPoly(instance_direction_map,
poly.round().astype(np.int32)[np.newaxis, :, :],
direct_label)
instance_direction_map = instance_direction_map.transpose([2, 0, 1])
return instance_direction_map[:2, ...]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
def draw_e2e_res(dt_boxes, strs, config, img, img_name):
if len(dt_boxes) > 0:
src_im = img
for box, str in zip(dt_boxes, strs):
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
cv2.putText(src_im, str, org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.7, color=(0, 255, 0), thickness=1)
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/e2e_results/"
if not os.path.exists(save_det_path):
os.makedirs(save_det_path)
save_path = os.path.join(save_det_path, os.path.basename(img_name))
cv2.imwrite(save_path, src_im)
logger.info("The e2e Image saved in {}".format(save_path))
def main():
global_config = config['Global']
# build model
model = build_model(config['Architecture'])
init_model(config, model, logger)
# build post process
post_process_class = build_post_process(config['PostProcess'])
# create data ops
transforms = []
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
continue
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image', 'shape']
transforms.append(op)
ops = create_operators(transforms, global_config)
save_res_path = config['Global']['save_res_path']
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
with open(save_res_path, "wb") as fout:
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
shape_list = np.expand_dims(batch[1], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, shape_list)
points, strs = post_result['points'], post_result['strs']
# write resule
dt_boxes_json = []
for poly, str in zip(points, strs):
tmp_json = {"transcription": str}
tmp_json['points'] = poly.tolist()
dt_boxes_json.append(tmp_json)
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
src_img = cv2.imread(file)
draw_e2e_res(points, strs, config, src_img, file)
logger.info("success!")
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
main()
\ No newline at end of file
......@@ -44,6 +44,7 @@ class ArgsParser(ArgumentParser):
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml'
assert args.config is not None, \
"Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
......@@ -374,7 +375,8 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PG'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册