Commit a0d1f923 authored by J Jethong

add different post process

Parent 0cd48c35
@@ -11,7 +11,7 @@ Global:
   # from static branch, load_static_weights must be set as True.
   # 2. If you want to finetune the pretrained models we provide in the docs,
   # you should set load_static_weights as False.
-  load_static_weights: True
+  load_static_weights: False
   cal_metric_during_train: False
   pretrained_model:
   checkpoints:
@@ -94,7 +94,7 @@ Eval:
     label_file_list: [./train_data/total_text/test/]
     transforms:
       - DecodeImage: # load image
-          img_mode: BGR
+          img_mode: RGB
          channel_first: False
       - E2ELabelEncode:
       - E2EResizeForTest:
......
@@ -200,16 +200,18 @@ class E2ELabelEncode(BaseRecLabelEncode):
         self.pad_num = len(self.dict)  # the length to pad

     def __call__(self, data):
+        text_label_index_list, temp_text = [], []
         texts = data['strs']
-        temp_texts = []
         for text in texts:
             text = text.lower()
-            text = self.encode(text)
-            if text is None:
-                return None
-            text = text + [self.pad_num] * (self.max_text_len - len(text))
-            temp_texts.append(text)
-        data['strs'] = np.array(temp_texts)
+            temp_text = []
+            for c_ in text:
+                if c_ in self.dict:
+                    temp_text.append(self.dict[c_])
+            temp_text = temp_text + [self.pad_num] * (self.max_text_len -
+                                                      len(temp_text))
+            text_label_index_list.append(temp_text)
+        data['strs'] = np.array(text_label_index_list)
         return data
......
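
Note on the hunk above: the encoder now maps each character through self.dict directly, silently dropping characters outside the lexicon instead of rejecting the whole label, then pads every sequence to max_text_len with pad_num. A minimal standalone sketch of that behavior (the toy dict and lengths are assumptions, not the project's real values):

```python
import numpy as np

# Assumed toy values; the real dict and max_text_len come from the config.
char_dict = {c: i for i, c in enumerate('abcdefghijklmnopqrstuvwxyz')}
pad_num = len(char_dict)  # index reserved for padding
max_text_len = 10

def encode_label(text):
    # keep only characters present in the lexicon
    ids = [char_dict[c] for c in text.lower() if c in char_dict]
    # pad to a fixed length so all labels batch into one np.array
    return ids + [pad_num] * (max_text_len - len(ids))

print(encode_label('Pizza!'))  # [15, 8, 25, 25, 0, 26, 26, 26, 26, 26]
```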
@@ -24,6 +24,7 @@ class PGDataSet(Dataset):
         self.logger = logger
         self.seed = seed
+        self.mode = mode
         global_config = config['Global']
         dataset_config = config[mode]['dataset']
         loader_config = config[mode]['loader']
@@ -62,10 +63,13 @@ class PGDataSet(Dataset):
         with open(poly_txt_path) as f:
             for line in f.readlines():
                 poly_str, txt = line.strip().split('\t')
-                poly = map(float, poly_str.split(','))
+                poly = list(map(float, poly_str.split(',')))
+                if self.mode.lower() == "eval":
+                    while len(poly) < 100:
+                        poly.append(-1)
                 text_polys.append(
                     np.array(
-                        list(poly), dtype=np.float32).reshape(-1, 2))
+                        poly, dtype=np.float32).reshape(-1, 2))
                 txts.append(txt)
                 txt_tags.append(txt == '###')
@@ -135,6 +139,10 @@ class PGDataSet(Dataset):
         try:
             if self.data_format == 'icdar':
                 im_path = os.path.join(data_path, 'rgb', data_line)
-                poly_path = os.path.join(data_path, 'poly',
-                                         data_line.split('.')[0] + '.txt')
+                if self.mode.lower() == "eval":
+                    poly_path = os.path.join(data_path, 'poly_gt',
+                                             data_line.split('.')[0] + '.txt')
+                else:
+                    poly_path = os.path.join(data_path, 'poly',
+                                             data_line.split('.')[0] + '.txt')
                 text_polys, text_tags, text_strs = self.extract_polys(poly_path)
......
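
Why the -1 padding: in eval mode every polygon's flattened coordinate list is padded to 100 floats with -1 so polygons with different point counts can stack into a single numpy batch; the matching filter lives in E2EMetric (next hunk). A small sketch of the padding step, assuming the 100-float budget from the diff:

```python
import numpy as np

def pad_poly(poly_str, budget=100):
    # "x1,y1,x2,y2,..." -> fixed-length list; -1 marks unused slots
    poly = list(map(float, poly_str.split(',')))
    while len(poly) < budget:
        poly.append(-1)
    return np.array(poly, dtype=np.float32).reshape(-1, 2)

padded = pad_poly('10,20,30,20,30,40,10,40')
print(padded.shape)          # (50, 2)
print((padded == -1).sum())  # 92 sentinel values
```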
@@ -33,10 +33,20 @@ class E2EMetric(object):
         self.reset()

     def __call__(self, preds, batch, **kwargs):
-        gt_polyons_batch = batch[2]
+        temp_gt_polyons_batch = batch[2]
         temp_gt_strs_batch = batch[3]
         ignore_tags_batch = batch[4]
+        gt_polyons_batch = []
         gt_strs_batch = []
+        temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
+        for temp_list in temp_gt_polyons_batch:
+            t = []
+            for index in temp_list:
+                if index[0] != -1 and index[1] != -1:
+                    t.append(index)
+            gt_polyons_batch.append(t)
         temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
         for temp_list in temp_gt_strs_batch:
             t = ""
@@ -46,7 +56,7 @@ class E2EMetric(object):
             gt_strs_batch.append(t)

         for pred, gt_polyons, gt_strs, ignore_tags in zip(
-                [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
+                [preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
             # prepare gt
             gt_info_list = [{
                 'points': gt_polyon,
......
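
On the metric side the padded rows are dropped again before matching: any point with a -1 in either coordinate is treated as padding. A sketch of the round trip (the helper name is hypothetical):

```python
import numpy as np

def strip_padding(padded_polys):
    # padded_polys: list of (N, 2) point lists; rows containing -1 are padding
    kept = []
    for poly in padded_polys:
        kept.append([p for p in poly if p[0] != -1 and p[1] != -1])
    return kept

batch = [np.full((50, 2), -1.0)]
batch[0][:4] = [[10, 20], [30, 20], [30, 40], [10, 40]]
print(len(strip_padding([b.tolist() for b in batch])[0]))  # 4 real points
```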
@@ -23,7 +23,8 @@ __dir__ = os.path.dirname(__file__)
 sys.path.append(__dir__)
 sys.path.append(os.path.join(__dir__, '..'))

-from ppocr.utils.e2e_utils.extract_textpoint import get_dict, generate_pivot_list, restore_poly
+from ppocr.utils.e2e_utils.extract_textpoint import *
+from ppocr.utils.e2e_utils.visual import *
 import paddle
@@ -37,6 +38,11 @@ class PGPostProcess(object):
         self.valid_set = valid_set
         self.score_thresh = score_thresh

+        # c++ la-nms is faster, but it only supports Python 3.5
+        self.is_python35 = False
+        if sys.version_info.major == 3 and sys.version_info.minor == 5:
+            self.is_python35 = True
+
     def __call__(self, outs_dict, shape_list):
         p_score = outs_dict['f_score']
         p_border = outs_dict['f_border']
@@ -52,17 +58,96 @@ class PGPostProcess(object):
         p_border = p_border[0]
         p_direction = p_direction[0]
         p_char = p_char[0]
         src_h, src_w, ratio_h, ratio_w = shape_list[0]
-        instance_yxs_list, seq_strs = generate_pivot_list(
+        is_curved = self.valid_set == "totaltext"
+        instance_yxs_list = generate_pivot_list(
             p_score,
             p_char,
             p_direction,
-            self.Lexicon_Table,
-            score_thresh=self.score_thresh)
-        poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
-                                                p_border, ratio_w, ratio_h,
-                                                src_w, src_h, self.valid_set)
+            score_thresh=self.score_thresh,
+            is_backbone=True,
+            is_curved=is_curved)
+        p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
+        char_seq_idx_set = []
+        for i in range(len(instance_yxs_list)):
+            gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
+            f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
+            feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
+            feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
+            feature_len = [len(feature_seq[0])]
+            feature_seq = paddle.to_tensor(feature_seq)
+            feature_len = np.array([feature_len]).astype(np.int64)
+            length = paddle.to_tensor(feature_len)
+            seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
+                input=feature_seq, blank=36, input_length=length)
+            seq_pred_str = seq_pred[0].numpy().tolist()[0]
+            seq_len = seq_pred[1].numpy()[0][0]
+            temp_t = []
+            for c in seq_pred_str[:seq_len]:
+                temp_t.append(c)
+            char_seq_idx_set.append(temp_t)
+        seq_strs = []
+        for char_idx_set in char_seq_idx_set:
+            pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
+            seq_strs.append(pr_str)
+        poly_list = []
+        keep_str_list = []
+        all_point_list = []
+        all_point_pair_list = []
+        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+            if len(yx_center_line) == 1:
+                yx_center_line.append(yx_center_line[-1])
+            offset_expand = 1.0
+            if self.valid_set == 'totaltext':
+                offset_expand = 1.2
+            point_pair_list = []
+            for batch_id, y, x in yx_center_line:
+                offset = p_border[:, y, x].reshape(2, 2)
+                if offset_expand != 1.0:
+                    offset_length = np.linalg.norm(
+                        offset, axis=1, keepdims=True)
+                    expand_length = np.clip(
+                        offset_length * (offset_expand - 1),
+                        a_min=0.5,
+                        a_max=3.0)
+                    offset_delta = offset / offset_length * expand_length
+                    offset = offset + offset_delta
+                ori_yx = np.array([y, x], dtype=np.float32)
+                point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
+                    [ratio_w, ratio_h]).reshape(-1, 2)
+                point_pair_list.append(point_pair)
+                all_point_list.append([
+                    int(round(x * 4.0 / ratio_w)),
+                    int(round(y * 4.0 / ratio_h))
+                ])
+                all_point_pair_list.append(point_pair.round().astype(np.int32)
+                                           .tolist())
+            detected_poly, pair_length_info = point_pair2poly(point_pair_list)
+            detected_poly = expand_poly_along_width(
+                detected_poly, shrink_ratio_of_width=0.2)
+            detected_poly[:, 0] = np.clip(
+                detected_poly[:, 0], a_min=0, a_max=src_w)
+            detected_poly[:, 1] = np.clip(
+                detected_poly[:, 1], a_min=0, a_max=src_h)
+            if len(keep_str) < 2:
+                continue
+            keep_str_list.append(keep_str)
+            if self.valid_set == 'partvgg':
+                middle_point = len(detected_poly) // 2
+                detected_poly = detected_poly[
+                    [0, middle_point - 1, middle_point, -1], :]
+                poly_list.append(detected_poly)
+            elif self.valid_set == 'totaltext':
+                poly_list.append(detected_poly)
+            else:
+                print('--> Not supported format.')
+                exit(-1)
         data = {
             'points': poly_list,
             'strs': keep_str_list,
......
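
The heavy lifting in the hunk above is: gather the C-way character logits at each pivot along an instance's sorted center line, then CTC-greedy-decode that sequence (blank index 36 for the 36-symbol lexicon). A framework-free numpy sketch of the same idea; the toy map sizes, lexicon, and center line are made up:

```python
import numpy as np
from itertools import groupby

rng = np.random.default_rng(0)
lexicon = list('abcdefghijklmnopqrstuvwxyz0123456789')  # 36 symbols
C = len(lexicon) + 1                       # +1 for the CTC blank at index 36
char_map = rng.normal(size=(32, 32, C))    # H x W x C logits
center_line = [(4, 5), (4, 6), (5, 7), (5, 8)]  # sorted (y, x) pivots

ys, xs = zip(*center_line)
logits_seq = char_map[list(ys), list(xs)]  # n x C sequence along the line
labels = np.argmax(logits_seq, axis=1)
# CTC greedy: merge repeated labels, then drop blanks
merged = [k for k, _ in groupby(labels)]
decoded = ''.join(lexicon[k] for k in merged if k != C - 1)
print(decoded)
```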
@@ -35,7 +35,7 @@ def get_socre(gt_dict, pred_dict):
     gt = []
     n = len(gt_dict)
     for i in range(n):
-        points = gt_dict[i]['points'].tolist()
+        points = gt_dict[i]['points']
         h = len(points)
         text = gt_dict[i]['text']
         xx = [
@@ -51,7 +51,7 @@ def get_socre(gt_dict, pred_dict):
             t_y.append(points[j][1])
         xx[1] = np.array([t_x], dtype='int16')
         xx[3] = np.array([t_y], dtype='int16')
-        if text != "":
+        if text != "" and "#" not in text:
             xx[4] = np.array([text], dtype='U{}'.format(len(text)))
             xx[5] = np.array(['c'], dtype='<U1')
         gt.append(xx)
@@ -89,17 +89,10 @@ def get_socre(gt_dict, pred_dict):
                 area(det_x, det_y)), 2)

    ##############################Initialization###################################
-    global_tp = 0
-    global_fp = 0
-    global_fn = 0
-    global_sigma = []
-    global_tau = []
-    tr = 0.7
-    tp = 0.6
-    fsc_k = 0.8
-    k = 2
-    global_pred_str = []
-    global_gt_str = []
+    # global_sigma = []
+    # global_tau = []
+    # global_pred_str = []
+    # global_gt_str = []
    ###############################################################################
    for input_id in range(allInputs):
@@ -147,281 +140,16 @@ def get_socre(gt_dict, pred_dict):
                     local_pred_str[det_id] = pred_seq_str
                     local_gt_str[gt_id] = gt_seq_str

-    global_sigma.append(local_sigma_table)
-    global_tau.append(local_tau_table)
-    global_pred_str.append(local_pred_str)
-    global_gt_str.append(local_gt_str)
+    global_sigma = local_sigma_table
+    global_tau = local_tau_table
+    global_pred_str = local_pred_str
+    global_gt_str = local_gt_str

-    global_accumulative_recall = 0
-    global_accumulative_precision = 0
-    total_num_gt = 0
-    total_num_det = 0
-    hit_str_count = 0
-    hit_count = 0
-
-    def one_to_one(local_sigma_table, local_tau_table,
-                   local_accumulative_recall, local_accumulative_precision,
-                   global_accumulative_recall, global_accumulative_precision,
-                   gt_flag, det_flag, idy):
-        hit_str_num = 0
-        for gt_id in range(num_gt):
-            gt_matching_qualified_sigma_candidates = np.where(
-                local_sigma_table[gt_id, :] > tr)
-            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
-                0].shape[0]
-            gt_matching_qualified_tau_candidates = np.where(
-                local_tau_table[gt_id, :] > tp)
-            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
-                0].shape[0]
-
-            det_matching_qualified_sigma_candidates = np.where(
-                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
-                > tr)
-            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
-                0].shape[0]
-            det_matching_qualified_tau_candidates = np.where(
-                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
-                tp)
-            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
-                0].shape[0]
-
-            if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
-                    (det_matching_num_qualified_sigma_candidates == 1) and (
-                    det_matching_num_qualified_tau_candidates == 1):
-                global_accumulative_recall = global_accumulative_recall + 1.0
-                global_accumulative_precision = global_accumulative_precision + 1.0
-                local_accumulative_recall = local_accumulative_recall + 1.0
-                local_accumulative_precision = local_accumulative_precision + 1.0
-
-                gt_flag[0, gt_id] = 1
-                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
-                # recg start
-                gt_str_cur = global_gt_str[idy][gt_id]
-                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
-                    0]]
-                if pred_str_cur == gt_str_cur:
-                    hit_str_num += 1
-                else:
-                    if pred_str_cur.lower() == gt_str_cur.lower():
-                        hit_str_num += 1
-                # recg end
-                det_flag[0, matched_det_id] = 1
-        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
-    def one_to_many(local_sigma_table, local_tau_table,
-                    local_accumulative_recall, local_accumulative_precision,
-                    global_accumulative_recall, global_accumulative_precision,
-                    gt_flag, det_flag, idy):
-        hit_str_num = 0
-        for gt_id in range(num_gt):
-            # skip the following if the groundtruth was matched
-            if gt_flag[0, gt_id] > 0:
-                continue
-
-            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
-            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
-
-            if num_non_zero_in_sigma >= k:
-                ####search for all detections that overlaps with this groundtruth
-                qualified_tau_candidates = np.where((local_tau_table[
-                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
-                num_qualified_tau_candidates = qualified_tau_candidates[
-                    0].shape[0]
-
-                if num_qualified_tau_candidates == 1:
-                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
-                            and
-                        (local_sigma_table[gt_id, qualified_tau_candidates] >=
-                         tr)):
-                        # became an one-to-one case
-                        global_accumulative_recall = global_accumulative_recall + 1.0
-                        global_accumulative_precision = global_accumulative_precision + 1.0
-                        local_accumulative_recall = local_accumulative_recall + 1.0
-                        local_accumulative_precision = local_accumulative_precision + 1.0
-
-                        gt_flag[0, gt_id] = 1
-                        det_flag[0, qualified_tau_candidates] = 1
-                        # recg start
-                        gt_str_cur = global_gt_str[idy][gt_id]
-                        pred_str_cur = global_pred_str[idy][
-                            qualified_tau_candidates[0].tolist()[0]]
-                        if pred_str_cur == gt_str_cur:
-                            hit_str_num += 1
-                        else:
-                            if pred_str_cur.lower() == gt_str_cur.lower():
-                                hit_str_num += 1
-                        # recg end
-                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
-                      >= tr):
-                    gt_flag[0, gt_id] = 1
-                    det_flag[0, qualified_tau_candidates] = 1
-                    # recg start
-                    gt_str_cur = global_gt_str[idy][gt_id]
-                    pred_str_cur = global_pred_str[idy][
-                        qualified_tau_candidates[0].tolist()[0]]
-                    if pred_str_cur == gt_str_cur:
-                        hit_str_num += 1
-                    else:
-                        if pred_str_cur.lower() == gt_str_cur.lower():
-                            hit_str_num += 1
-                    # recg end
-
-                    global_accumulative_recall = global_accumulative_recall + fsc_k
-                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
-                    local_accumulative_recall = local_accumulative_recall + fsc_k
-                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
-        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
-    def many_to_one(local_sigma_table, local_tau_table,
-                    local_accumulative_recall, local_accumulative_precision,
-                    global_accumulative_recall, global_accumulative_precision,
-                    gt_flag, det_flag, idy):
-        hit_str_num = 0
-        for det_id in range(num_det):
-            # skip the following if the detection was matched
-            if det_flag[0, det_id] > 0:
-                continue
-
-            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
-            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
-
-            if num_non_zero_in_tau >= k:
-                ####search for all detections that overlaps with this groundtruth
-                qualified_sigma_candidates = np.where((
-                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
-                num_qualified_sigma_candidates = qualified_sigma_candidates[
-                    0].shape[0]
-
-                if num_qualified_sigma_candidates == 1:
-                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
-                         tp) and
-                        (local_sigma_table[qualified_sigma_candidates, det_id]
-                         >= tr)):
-                        # became an one-to-one case
-                        global_accumulative_recall = global_accumulative_recall + 1.0
-                        global_accumulative_precision = global_accumulative_precision + 1.0
-                        local_accumulative_recall = local_accumulative_recall + 1.0
-                        local_accumulative_precision = local_accumulative_precision + 1.0
-
-                        gt_flag[0, qualified_sigma_candidates] = 1
-                        det_flag[0, det_id] = 1
-                        # recg start
-                        pred_str_cur = global_pred_str[idy][det_id]
-                        gt_len = len(qualified_sigma_candidates[0])
-                        for idx in range(gt_len):
-                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
-                                idx]
-                            if ele_gt_id not in global_gt_str[idy]:
-                                continue
-                            gt_str_cur = global_gt_str[idy][ele_gt_id]
-                            if pred_str_cur == gt_str_cur:
-                                hit_str_num += 1
-                                break
-                            else:
-                                if pred_str_cur.lower() == gt_str_cur.lower():
-                                    hit_str_num += 1
-                                break
-                        # recg end
-                elif (np.sum(local_tau_table[qualified_sigma_candidates,
-                             det_id]) >= tp):
-                    det_flag[0, det_id] = 1
-                    gt_flag[0, qualified_sigma_candidates] = 1
-                    # recg start
-                    pred_str_cur = global_pred_str[idy][det_id]
-                    gt_len = len(qualified_sigma_candidates[0])
-                    for idx in range(gt_len):
-                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
-                        if ele_gt_id not in global_gt_str[idy]:
-                            continue
-                        gt_str_cur = global_gt_str[idy][ele_gt_id]
-                        if pred_str_cur == gt_str_cur:
-                            hit_str_num += 1
-                            break
-                        else:
-                            if pred_str_cur.lower() == gt_str_cur.lower():
-                                hit_str_num += 1
-                            break
-                    # recg end
-
-                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
-                    global_accumulative_precision = global_accumulative_precision + fsc_k
-                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
-                    local_accumulative_precision = local_accumulative_precision + fsc_k
-        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
     single_data = {}
-    for idx in range(len(global_sigma)):
-        local_sigma_table = global_sigma[idx]
-        local_tau_table = global_tau[idx]
-        num_gt = local_sigma_table.shape[0]
-        num_det = local_sigma_table.shape[1]
-        total_num_gt = total_num_gt + num_gt
-        total_num_det = total_num_det + num_det
-        local_accumulative_recall = 0
-        local_accumulative_precision = 0
-        gt_flag = np.zeros((1, num_gt))
-        det_flag = np.zeros((1, num_det))
-
-        #######first check for one-to-one case##########
-        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
-        gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
-                                                    local_accumulative_recall, local_accumulative_precision,
-                                                    global_accumulative_recall, global_accumulative_precision,
-                                                    gt_flag, det_flag, idx)
-        hit_str_count += hit_str_num
-
-        #######then check for one-to-many case##########
-        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
-        gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
-                                                     local_accumulative_recall, local_accumulative_precision,
-                                                     global_accumulative_recall, global_accumulative_precision,
-                                                     gt_flag, det_flag, idx)
-        hit_str_count += hit_str_num
-
-        #######then check for many-to-one case##########
-        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
-        gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
-                                                     local_accumulative_recall, local_accumulative_precision,
-                                                     global_accumulative_recall, global_accumulative_precision,
-                                                     gt_flag, det_flag, idx)
-        hit_str_count += hit_str_num
-
-        # fid = open(fid_path, 'a+')
-        try:
-            local_precision = local_accumulative_precision / num_det
-        except ZeroDivisionError:
-            local_precision = 0
-
-        try:
-            local_recall = local_accumulative_recall / num_gt
-        except ZeroDivisionError:
-            local_recall = 0
-
-        try:
-            local_f_score = 2 * local_precision * local_recall / (
-                local_precision + local_recall)
-        except ZeroDivisionError:
-            local_f_score = 0
-
     single_data['sigma'] = global_sigma
     single_data['global_tau'] = global_tau
     single_data['global_pred_str'] = global_pred_str
     single_data['global_gt_str'] = global_gt_str
-    single_data["recall"] = local_recall
-    single_data['precision'] = local_precision
-    single_data['f_score'] = local_f_score
     return single_data
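
With this change get_socre no longer aggregates anything: it returns the raw per-image sigma/tau match tables plus the prediction and ground-truth strings, and all accumulation (including the guarded divisions that used to live here) happens once in combine_results below. A standalone sketch of that guarded division pattern, equivalent to the try/except ZeroDivisionError blocks in the metric code:

```python
def safe_prf(acc_recall, acc_precision, num_gt, num_det):
    # division guards mirror the metric's try/except ZeroDivisionError blocks
    recall = acc_recall / num_gt if num_gt else 0
    precision = acc_precision / num_det if num_det else 0
    denom = precision + recall
    f_score = 2 * precision * recall / denom if denom else 0
    return precision, recall, f_score

print(safe_prf(8.0, 7.5, 10, 9))  # (0.833..., 0.8, 0.816...)
```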
@@ -435,10 +163,10 @@ def combine_results(all_data):
     global_pred_str = []
     global_gt_str = []
     for data in all_data:
-        global_sigma.append(data['sigma'][0])
-        global_tau.append(data['global_tau'][0])
-        global_pred_str.append(data['global_pred_str'][0])
-        global_gt_str.append(data['global_gt_str'][0])
+        global_sigma.append(data['sigma'])
+        global_tau.append(data['global_tau'])
+        global_pred_str.append(data['global_pred_str'])
+        global_gt_str.append(data['global_gt_str'])

     global_accumulative_recall = 0
     global_accumulative_precision = 0
@@ -676,6 +404,8 @@ def combine_results(all_data):
             local_accumulative_recall, local_accumulative_precision,
             global_accumulative_recall, global_accumulative_precision,
             gt_flag, det_flag, idx)
+        hit_str_count += hit_str_num
+
     try:
         recall = global_accumulative_recall / total_num_gt
     except ZeroDivisionError:
......
@@ -17,9 +17,11 @@ from __future__ import division
 from __future__ import print_function

 import cv2
+import math
 import numpy as np
 from itertools import groupby
-from cv2.ximgproc import thinning as thin
+from skimage.morphology._skeletonize import thin


 def get_dict(character_dict_path):
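
Dependency note: thinning moves from cv2.ximgproc (an opencv-contrib extra) to scikit-image. The diff imports through the private module path skimage.morphology._skeletonize; the public equivalent is skimage.morphology.thin, as in this sketch:

```python
import numpy as np
from skimage.morphology import thin  # public alias of _skeletonize.thin

mask = np.zeros((7, 11), dtype=bool)
mask[2:5, 1:10] = True               # a thick horizontal bar
skeleton = thin(mask)                # reduced to a 1-pixel-wide center line
print(skeleton.sum(), mask.sum())    # far fewer pixels after thinning
```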
@@ -33,39 +35,87 @@ def get_dict(character_dict_path):
     return dict_character


-def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
+def softmax(logits):
+    """
+    logits: N x d
+    """
+    max_value = np.max(logits, axis=1, keepdims=True)
+    exp = np.exp(logits - max_value)
+    exp_sum = np.sum(exp, axis=1, keepdims=True)
+    dist = exp / exp_sum
+    return dist
+
+
+def get_keep_pos_idxs(labels, remove_blank=None):
+    """
+    Remove duplicates and get pos idxs of kept items.
+    remove_blank should be None or the blank index (e.g. 95).
+    """
+    duplicate_len_list = []
+    keep_pos_idx_list = []
+    keep_char_idx_list = []
+    for k, v_ in groupby(labels):
+        current_len = len(list(v_))
+        if k != remove_blank:
+            current_idx = int(sum(duplicate_len_list) + current_len // 2)
+            keep_pos_idx_list.append(current_idx)
+            keep_char_idx_list.append(k)
+        duplicate_len_list.append(current_len)
+    return keep_char_idx_list, keep_pos_idx_list
+
+
+def remove_blank(labels, blank=0):
+    new_labels = [x for x in labels if x != blank]
+    return new_labels
+
+
+def insert_blank(labels, blank=0):
+    new_labels = [blank]
+    for l in labels:
+        new_labels += [l, blank]
+    return new_labels
+
+
+def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
+    """
+    CTC greedy (best path) decoder.
+    """
+    raw_str = np.argmax(np.array(probs_seq), axis=1)
+    remove_blank_in_pos = None if keep_blank_in_idxs else blank
+    dedup_str, keep_idx_list = get_keep_pos_idxs(
+        raw_str, remove_blank=remove_blank_in_pos)
+    dst_str = remove_blank(dedup_str, blank=blank)
+    return dst_str, keep_idx_list
+
+
+def instance_ctc_greedy_decoder(gather_info,
+                                logits_map,
+                                keep_blank_in_idxs=True):
+    """
+    gather_info: [[y, x], [y, x] ...]
+    logits_map: H x W x (n_chars + 1)
+    """
     _, _, C = logits_map.shape
     ys, xs = zip(*gather_info)
-    logits_seq = logits_map[list(ys), list(xs)]
-    probs_seq = logits_seq
-    labels = np.argmax(probs_seq, axis=1)
-    dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
-    detal = len(gather_info) // (pts_num - 1)
-    keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
+    logits_seq = logits_map[list(ys), list(xs)]  # n x 96
+    probs_seq = softmax(logits_seq)
+    dst_str, keep_idx_list = ctc_greedy_decoder(
+        probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
     keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
     return dst_str, keep_gather_list


-def ctc_decoder_for_image(gather_info_list,
-                          logits_map,
-                          Lexicon_Table,
-                          pts_num=6):
+def ctc_decoder_for_image(gather_info_list, logits_map,
+                          keep_blank_in_idxs=True):
     """
     CTC decoder using multiple processes.
     """
-    decoder_str = []
-    decoder_xys = []
+    decoder_results = []
     for gather_info in gather_info_list:
-        if len(gather_info) < pts_num:
-            continue
-        dst_str, xys_list = instance_ctc_greedy_decoder(
-            gather_info, logits_map, pts_num=pts_num)
-        dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
-        if len(dst_str_readable) < 2:
-            continue
-        decoder_str.append(dst_str_readable)
-        decoder_xys.append(xys_list)
-    return decoder_str, decoder_xys
+        res = instance_ctc_greedy_decoder(
+            gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
+        decoder_results.append(res)
+    return decoder_results


 def sort_with_direction(pos_list, f_direction):
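
A quick check of the new decoder helpers on a hand-built probability sequence (the values are made up; the import assumes the functions above are on the path): repeats collapse, blanks are dropped from the string, and with keep_blank_in_idxs=True the kept positions still include the middle of every run, blanks included, so the center-line geometry is preserved.

```python
import numpy as np
from ppocr.utils.e2e_utils.extract_textpoint import ctc_greedy_decoder

# toy 6-step distribution over 3 classes, class 2 = blank
probs_seq = np.array([
    [0.8, 0.1, 0.1],  # 'a'
    [0.7, 0.2, 0.1],  # 'a' (repeat, merged)
    [0.1, 0.1, 0.8],  # blank
    [0.1, 0.8, 0.1],  # 'b'
    [0.1, 0.7, 0.2],  # 'b' (repeat, merged)
    [0.1, 0.1, 0.8],  # blank
])
dst_str, keep_idx_list = ctc_greedy_decoder(probs_seq, blank=2)
print(dst_str)        # [0, 1] -> "ab" after lexicon lookup
print(keep_idx_list)  # [1, 2, 4, 5]: middle index of every run, blanks kept
```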
@@ -107,6 +157,58 @@ def sort_with_direction(pos_list, f_direction):
     return sorted_point, np.array(sorted_direction)


+def add_id(pos_list, image_id=0):
+    """
+    Add id for gather feature, for inference.
+    """
+    new_list = []
+    for item in pos_list:
+        new_list.append((image_id, item[0], item[1]))
+    return new_list
+
+
+def sort_and_expand_with_direction(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[y, x], [y, x], [y, x] ...]
+    """
+    h, w, _ = f_direction.shape
+    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+    # expand along
+    point_num = len(sorted_list)
+    sub_direction_len = max(point_num // 3, 2)
+    left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+    left_average_len = np.linalg.norm(left_average_direction)
+    left_start = np.array(sorted_list[0])
+    left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+    right_average_len = np.linalg.norm(right_average_direction)
+    right_step = right_average_direction / (right_average_len + 1e-6)
+    right_start = np.array(sorted_list[-1])
+
+    append_num = max(
+        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+    left_list = []
+    right_list = []
+    for i in range(append_num):
+        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ly < h and lx < w and (ly, lx) not in left_list:
+            left_list.append((ly, lx))
+        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ry < h and rx < w and (ry, rx) not in right_list:
+            right_list.append((ry, rx))
+
+    all_list = left_list[::-1] + sorted_list + right_list
+    return all_list
+
+
 def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
     """
     f_direction: h x w x 2
@@ -116,6 +218,7 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
     h, w, _ = f_direction.shape
     sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+    # expand along
     point_num = len(sorted_list)
     sub_direction_len = max(point_num // 3, 2)
     left_direction = point_direction[:sub_direction_len, :]
@@ -159,125 +262,271 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
     return all_list


-def point_pair2poly(point_pair_list):
-    """
-    Transfer vertical point_pairs into poly point in clockwise.
-    """
-    point_num = len(point_pair_list) * 2
-    point_list = [0] * point_num
-    for idx, point_pair in enumerate(point_pair_list):
-        point_list[idx] = point_pair[0]
-        point_list[point_num - 1 - idx] = point_pair[1]
-    return np.array(point_list).reshape(-1, 2)
-
-
-def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
-    ratio_pair = np.array(
-        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
-    p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
-    p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
-    return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
-
-
-def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
-    """
-    expand poly along width.
-    """
-    point_num = poly.shape[0]
-    left_quad = np.array(
-        [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
-    left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
-                 (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
-    left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
-    right_quad = np.array(
-        [
-            poly[point_num // 2 - 2], poly[point_num // 2 - 1],
-            poly[point_num // 2], poly[point_num // 2 + 1]
-        ],
-        dtype=np.float32)
-    right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
-                  (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
-    right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
-    poly[0] = left_quad_expand[0]
-    poly[-1] = left_quad_expand[-1]
-    poly[point_num // 2 - 1] = right_quad_expand[1]
-    poly[point_num // 2] = right_quad_expand[2]
-    return poly
-
-
-def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
-                 src_h, valid_set):
-    poly_list = []
-    keep_str_list = []
-    for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
-        if len(keep_str) < 2:
-            print('--> too short, {}'.format(keep_str))
-            continue
-        offset_expand = 1.0
-        if valid_set == 'totaltext':
-            offset_expand = 1.2
-        point_pair_list = []
-        for y, x in yx_center_line:
-            offset = p_border[:, y, x].reshape(2, 2) * offset_expand
-            ori_yx = np.array([y, x], dtype=np.float32)
-            point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
-                [ratio_w, ratio_h]).reshape(-1, 2)
-            point_pair_list.append(point_pair)
-        detected_poly = point_pair2poly(point_pair_list)
-        detected_poly = expand_poly_along_width(
-            detected_poly, shrink_ratio_of_width=0.2)
-        detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
-        detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
-        keep_str_list.append(keep_str)
-        if valid_set == 'partvgg':
-            middle_point = len(detected_poly) // 2
-            detected_poly = detected_poly[
-                [0, middle_point - 1, middle_point, -1], :]
-            poly_list.append(detected_poly)
-        elif valid_set == 'totaltext':
-            poly_list.append(detected_poly)
-        else:
-            print('--> Not supported format.')
-            exit(-1)
-    return poly_list, keep_str_list
-
-
-def generate_pivot_list(p_score,
-                        p_char_maps,
-                        f_direction,
-                        Lexicon_Table,
-                        score_thresh=0.5):
-    """
-    return center point and end point of TCL instance; filter with the char maps;
-    """
-    p_score = p_score[0]
-    f_direction = f_direction.transpose(1, 2, 0)
-    ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255,
-                                   cv2.THRESH_BINARY)
-    skeleton_map = thin(p_tcl_map.astype('uint8'))
-    instance_count, instance_label_map = cv2.connectedComponents(
-        skeleton_map, connectivity=8)
-
-    # get TCL Instance
-    all_pos_yxs = []
-    if instance_count > 0:
-        for instance_id in range(1, instance_count):
-            pos_list = []
-            ys, xs = np.where(instance_label_map == instance_id)
-            pos_list = list(zip(ys, xs))
-            if len(pos_list) < 3:
-                continue
-            pos_list_sorted = sort_and_expand_with_direction_v2(
-                pos_list, f_direction, p_tcl_map)
-            all_pos_yxs.append(pos_list_sorted)
-
-    p_char_maps = p_char_maps.transpose([1, 2, 0])
-    decoded_str, keep_yxs_list = ctc_decoder_for_image(
-        all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
-    return keep_yxs_list, decoded_str
+def generate_pivot_list_curved(p_score,
+                               p_char_maps,
+                               f_direction,
+                               score_thresh=0.5,
+                               is_expand=True,
+                               is_backbone=False,
+                               image_id=0):
+    """
+    return center point and end point of TCL instance; filter with the char maps;
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map = (p_score > score_thresh) * 1.0
+    skeleton_map = thin(p_tcl_map)
+    instance_count, instance_label_map = cv2.connectedComponents(
+        skeleton_map.astype(np.uint8), connectivity=8)
+
+    # get TCL Instance
+    all_pos_yxs = []
+    center_pos_yxs = []
+    end_points_yxs = []
+    instance_center_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 3:
+                continue
+
+            if is_expand:
+                pos_list_sorted = sort_and_expand_with_direction_v2(
+                    pos_list, f_direction, p_tcl_map)
+            else:
+                pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
+            all_pos_yxs.append(pos_list_sorted)
+
+    # use decoder to filter background points.
+    p_char_maps = p_char_maps.transpose([1, 2, 0])
+    decode_res = ctc_decoder_for_image(
+        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+    for decoded_str, keep_yxs_list in decode_res:
+        if is_backbone:
+            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+            instance_center_pos_yxs.append(keep_yxs_list_with_id)
+        else:
+            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+            center_pos_yxs.extend(keep_yxs_list)
+
+    if is_backbone:
+        return instance_center_pos_yxs
+    else:
+        return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list_horizontal(p_score,
+                                   p_char_maps,
+                                   f_direction,
+                                   score_thresh=0.5,
+                                   is_backbone=False,
+                                   image_id=0):
+    """
+    return center point and end point of TCL instance; filter with the char maps;
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map_bi = (p_score > score_thresh) * 1.0
+    instance_count, instance_label_map = cv2.connectedComponents(
+        p_tcl_map_bi.astype(np.uint8), connectivity=8)
+
+    # get TCL Instance
+    all_pos_yxs = []
+    center_pos_yxs = []
+    end_points_yxs = []
+    instance_center_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 5:
+                continue
+
+            # add rule here
+            main_direction = extract_main_direction(pos_list,
+                                                    f_direction)  # y x
+            reference_direction = np.array([0, 1]).reshape([-1, 2])  # y x
+            is_h_angle = abs(np.sum(
+                main_direction * reference_direction)) < math.cos(math.pi / 180 *
+                                                                  70)
+            point_yxs = np.array(pos_list)
+            max_y, max_x = np.max(point_yxs, axis=0)
+            min_y, min_x = np.min(point_yxs, axis=0)
+            is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
+
+            pos_list_final = []
+            if is_h_len:
+                xs = np.unique(xs)
+                for x in xs:
+                    ys = instance_label_map[:, x].copy().reshape((-1, ))
+                    y = int(np.where(ys == instance_id)[0].mean())
+                    pos_list_final.append((y, x))
+            else:
+                ys = np.unique(ys)
+                for y in ys:
+                    xs = instance_label_map[y, :].copy().reshape((-1, ))
+                    x = int(np.where(xs == instance_id)[0].mean())
+                    pos_list_final.append((y, x))
+
+            pos_list_sorted, _ = sort_with_direction(pos_list_final,
+                                                     f_direction)
+            all_pos_yxs.append(pos_list_sorted)
+
+    # use decoder to filter background points.
+    p_char_maps = p_char_maps.transpose([1, 2, 0])
+    decode_res = ctc_decoder_for_image(
+        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+    for decoded_str, keep_yxs_list in decode_res:
+        if is_backbone:
+            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+            instance_center_pos_yxs.append(keep_yxs_list_with_id)
+        else:
+            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+            center_pos_yxs.extend(keep_yxs_list)
+
+    if is_backbone:
+        return instance_center_pos_yxs
+    else:
+        return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list(p_score,
+                        p_char_maps,
+                        f_direction,
+                        score_thresh=0.5,
+                        is_backbone=False,
+                        is_curved=True,
+                        image_id=0):
+    """
+    Wrap all the functions together.
+    """
+    if is_curved:
+        return generate_pivot_list_curved(
+            p_score,
+            p_char_maps,
+            f_direction,
+            score_thresh=score_thresh,
+            is_expand=True,
+            is_backbone=is_backbone,
+            image_id=image_id)
+    else:
+        return generate_pivot_list_horizontal(
+            p_score,
+            p_char_maps,
+            f_direction,
+            score_thresh=score_thresh,
+            is_backbone=is_backbone,
+            image_id=image_id)
+
+
+# for refine module
+def extract_main_direction(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[y, x], [y, x], [y, x] ...]
+    """
+    pos_list = np.array(pos_list)
+    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
+    point_direction = point_direction[:, ::-1]  # x, y -> y, x
+    average_direction = np.mean(point_direction, axis=0, keepdims=True)
+    average_direction = average_direction / (
+        np.linalg.norm(average_direction) + 1e-6)
+    return average_direction
+
+
+def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+    """
+    pos_list_full = np.array(pos_list).reshape(-1, 3)
+    pos_list = pos_list_full[:, 1:]
+    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
+    point_direction = point_direction[:, ::-1]  # x, y -> y, x
+    average_direction = np.mean(point_direction, axis=0, keepdims=True)
+    pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+    sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+    return sorted_list
+
+
+def sort_by_direction_with_image_id(pos_list, f_direction):
+    """
+    f_direction: h x w x 2
+    pos_list: [[y, x], [y, x], [y, x] ...]
+    """
+
+    def sort_part_with_direction(pos_list_full, point_direction):
+        pos_list_full = np.array(pos_list_full).reshape(-1, 3)
+        pos_list = pos_list_full[:, 1:]
+        point_direction = np.array(point_direction).reshape(-1, 2)
+        average_direction = np.mean(point_direction, axis=0, keepdims=True)
+        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+        sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+        return sorted_list, sorted_direction
+
+    pos_list = np.array(pos_list).reshape(-1, 3)
+    point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]]  # x, y
+    point_direction = point_direction[:, ::-1]  # x, y -> y, x
+    sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+                                                              point_direction)
+
+    point_num = len(sorted_point)
+    if point_num >= 16:
+        middle_num = point_num // 2
+        first_part_point = sorted_point[:middle_num]
+        first_point_direction = sorted_direction[:middle_num]
+        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+            first_part_point, first_point_direction)
+
+        last_part_point = sorted_point[middle_num:]
+        last_point_direction = sorted_direction[middle_num:]
+        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+            last_part_point, last_point_direction)
+        sorted_point = sorted_first_part_point + sorted_last_part_point
+        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+    return sorted_point
+
+
+def generate_pivot_list_tt_inference(p_score,
+                                     p_char_maps,
+                                     f_direction,
+                                     score_thresh=0.5,
+                                     is_backbone=False,
+                                     is_curved=True,
+                                     image_id=0):
+    """
+    return center point and end point of TCL instance; filter with the char maps;
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map = (p_score > score_thresh) * 1.0
+    skeleton_map = thin(p_tcl_map)
+    instance_count, instance_label_map = cv2.connectedComponents(
+        skeleton_map.astype(np.uint8), connectivity=8)
+
+    # get TCL Instance
+    all_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 3:
+                continue
+            pos_list_sorted = sort_and_expand_with_direction_v2(
+                pos_list, f_direction, p_tcl_map)
+            pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
+            all_pos_yxs.append(pos_list_sorted_with_id)
+    return all_pos_yxs
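
End to end, the new entry point takes the network's score, character, and direction maps and returns per-instance pivot lists. A shape-only smoke test with all-zero maps (so no instances fire), assuming a PaddleOCR checkout on PYTHONPATH with these changes applied:

```python
import numpy as np
from ppocr.utils.e2e_utils.extract_textpoint import generate_pivot_list

H, W = 128, 128
p_score = np.zeros((1, H, W), dtype=np.float32)      # TCL score map
p_char = np.zeros((37, H, W), dtype=np.float32)      # 36 chars + blank
f_direction = np.zeros((2, H, W), dtype=np.float32)  # per-pixel direction

# is_backbone=True mirrors PGPostProcess: one pivot list per text instance,
# each point tagged (image_id, y, x) for paddle.gather_nd.
instance_yxs = generate_pivot_list(
    p_score, p_char, f_direction, score_thresh=0.5,
    is_backbone=True, is_curved=True)
print(instance_yxs)  # [] -- an empty score map yields no instances
```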
...@@ -151,7 +151,7 @@ if __name__ == "__main__": ...@@ -151,7 +151,7 @@ if __name__ == "__main__":
src_im = utility.draw_e2e_res(points, strs, image_file) src_im = utility.draw_e2e_res(points, strs, image_file)
img_name_pure = os.path.split(image_file)[-1] img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save, img_path = os.path.join(draw_img_save,
"e2e_res_{}_pgnet".format(img_name_pure)) "e2e_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im) cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path)) logger.info("The visualized image saved in {}".format(img_path))
if count > 1: if count > 1:
......
2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza
2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA