提交 b5de79b2 编写于 作者: J Jethong

refine code

上级 c5f33b00
...@@ -342,6 +342,7 @@ def generate_pivot_list_curved(p_score, ...@@ -342,6 +342,7 @@ def generate_pivot_list_curved(p_score,
center_pos_yxs = [] center_pos_yxs = []
end_points_yxs = [] end_points_yxs = []
instance_center_pos_yxs = [] instance_center_pos_yxs = []
pred_strs = []
if instance_count > 0: if instance_count > 0:
for instance_id in range(1, instance_count): for instance_id in range(1, instance_count):
pos_list = [] pos_list = []
...@@ -367,12 +368,13 @@ def generate_pivot_list_curved(p_score, ...@@ -367,12 +368,13 @@ def generate_pivot_list_curved(p_score,
if is_backbone: if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id) keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id) instance_center_pos_yxs.append(keep_yxs_list_with_id)
pred_strs.append(decoded_str)
else: else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1])) end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list) center_pos_yxs.extend(keep_yxs_list)
if is_backbone: if is_backbone:
return instance_center_pos_yxs return pred_strs, instance_center_pos_yxs
else: else:
return center_pos_yxs, end_points_yxs return center_pos_yxs, end_points_yxs
......
...@@ -85,32 +85,13 @@ class PGNet_PostProcess(object): ...@@ -85,32 +85,13 @@ class PGNet_PostProcess(object):
p_char = p_char[0] p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = self.shape_list[0] src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
is_curved = self.valid_set == "totaltext" is_curved = self.valid_set == "totaltext"
instance_yxs_list = generate_pivot_list_slow( char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
p_score, p_score,
p_char, p_char,
p_direction, p_direction,
score_thresh=self.score_thresh, score_thresh=self.score_thresh,
is_backbone=True, is_backbone=True,
is_curved=is_curved) is_curved=is_curved)
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
char_seq_idx_set = []
for i in range(len(instance_yxs_list)):
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
feature_len = [len(feature_seq[0])]
featyre_seq = paddle.to_tensor(feature_seq)
feature_len = np.array([feature_len]).astype(np.int64)
length = paddle.to_tensor(feature_len)
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
input=featyre_seq, blank=36, input_length=length)
seq_pred_str = seq_pred[0].numpy().tolist()[0]
seq_len = seq_pred[1].numpy()[0][0]
temp_t = []
for c in seq_pred_str[:seq_len]:
temp_t.append(c)
char_seq_idx_set.append(temp_t)
seq_strs = [] seq_strs = []
for char_idx_set in char_seq_idx_set: for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set]) pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册