From b5de79b2c9ab516030765ad5df2b137b3ad95e89 Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Thu, 15 Apr 2021 19:43:31 +0800 Subject: [PATCH] refine code --- .../utils/e2e_utils/extract_textpoint_slow.py | 4 +++- ppocr/utils/e2e_utils/pgnet_pp_utils.py | 21 +------------------ 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/ppocr/utils/e2e_utils/extract_textpoint_slow.py b/ppocr/utils/e2e_utils/extract_textpoint_slow.py index db0c30e6..ace46fba 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint_slow.py +++ b/ppocr/utils/e2e_utils/extract_textpoint_slow.py @@ -342,6 +342,7 @@ def generate_pivot_list_curved(p_score, center_pos_yxs = [] end_points_yxs = [] instance_center_pos_yxs = [] + pred_strs = [] if instance_count > 0: for instance_id in range(1, instance_count): pos_list = [] @@ -367,12 +368,13 @@ def generate_pivot_list_curved(p_score, if is_backbone: keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id) instance_center_pos_yxs.append(keep_yxs_list_with_id) + pred_strs.append(decoded_str) else: end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1])) center_pos_yxs.extend(keep_yxs_list) if is_backbone: - return instance_center_pos_yxs + return pred_strs, instance_center_pos_yxs else: return center_pos_yxs, end_points_yxs diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py index 64bfd372..db1654f3 100644 --- a/ppocr/utils/e2e_utils/pgnet_pp_utils.py +++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py @@ -85,32 +85,13 @@ class PGNet_PostProcess(object): p_char = p_char[0] src_h, src_w, ratio_h, ratio_w = self.shape_list[0] is_curved = self.valid_set == "totaltext" - instance_yxs_list = generate_pivot_list_slow( + char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow( p_score, p_char, p_direction, score_thresh=self.score_thresh, is_backbone=True, is_curved=is_curved) - p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0)) - char_seq_idx_set = [] - for i in range(len(instance_yxs_list)): - gather_info_lod = paddle.to_tensor(instance_yxs_list[i]) - f_char_map = paddle.transpose(p_char, [0, 2, 3, 1]) - feature_seq = paddle.gather_nd(f_char_map, gather_info_lod) - feature_seq = np.expand_dims(feature_seq.numpy(), axis=0) - feature_len = [len(feature_seq[0])] - featyre_seq = paddle.to_tensor(feature_seq) - feature_len = np.array([feature_len]).astype(np.int64) - length = paddle.to_tensor(feature_len) - seq_pred = paddle.fluid.layers.ctc_greedy_decoder( - input=featyre_seq, blank=36, input_length=length) - seq_pred_str = seq_pred[0].numpy().tolist()[0] - seq_len = seq_pred[1].numpy()[0][0] - temp_t = [] - for c in seq_pred_str[:seq_len]: - temp_t.append(c) - char_seq_idx_set.append(temp_t) seq_strs = [] for char_idx_set in char_seq_idx_set: pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set]) -- GitLab