提交 4c0b0873 编写于 作者: W wangjingyeye

update pgnet

上级 929b4f45
...@@ -33,7 +33,7 @@ Architecture: ...@@ -33,7 +33,7 @@ Architecture:
name: PGFPN name: PGFPN
Head: Head:
name: PGHead name: PGHead
tcc_channels: 37 # the length of character dict character_dict_path: ppocr/utils/ic15_dict.txt # the same as Global:character_dict_path
Loss: Loss:
name: PGLoss name: PGLoss
...@@ -58,7 +58,7 @@ PostProcess: ...@@ -58,7 +58,7 @@ PostProcess:
name: PGPostProcess name: PGPostProcess
score_thresh: 0.5 score_thresh: 0.5
mode: fast # fast or slow two ways mode: fast # fast or slow two ways
tcc_type: v3 # same as PGProcessTrain: tcc_type point_gather_mode: v3 # same as PGProcessTrain: point_gather_mode
Metric: Metric:
name: E2EMetric name: E2EMetric
...@@ -85,7 +85,7 @@ Train: ...@@ -85,7 +85,7 @@ Train:
min_crop_size: 24 min_crop_size: 24
min_text_size: 4 min_text_size: 4
max_text_size: 512 max_text_size: 512
tcc_type: v3 # two ways, v2 is original code, v3 is updated code point_gather_mode: v3 # two ways, v2 is original code, v3 is updated code
- KeepKeys: - KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader: loader:
......
...@@ -33,7 +33,7 @@ class PGProcessTrain(object): ...@@ -33,7 +33,7 @@ class PGProcessTrain(object):
min_crop_size=24, min_crop_size=24,
min_text_size=4, min_text_size=4,
max_text_size=512, max_text_size=512,
tcc_type='v3', point_gather_mode='v3',
**kwargs): **kwargs):
self.tcl_len = tcl_len self.tcl_len = tcl_len
self.max_text_length = max_text_length self.max_text_length = max_text_length
...@@ -45,7 +45,7 @@ class PGProcessTrain(object): ...@@ -45,7 +45,7 @@ class PGProcessTrain(object):
self.min_text_size = min_text_size self.min_text_size = min_text_size
self.max_text_size = max_text_size self.max_text_size = max_text_size
self.use_resize = use_resize self.use_resize = use_resize
self.tcc_type = tcc_type self.point_gather_mode = point_gather_mode
self.Lexicon_Table = self.get_dict(character_dict_path) self.Lexicon_Table = self.get_dict(character_dict_path)
self.pad_num = len(self.Lexicon_Table) self.pad_num = len(self.Lexicon_Table)
self.img_id = 0 self.img_id = 0
...@@ -531,7 +531,7 @@ class PGProcessTrain(object): ...@@ -531,7 +531,7 @@ class PGProcessTrain(object):
average_shrink_height = self.calculate_average_height( average_shrink_height = self.calculate_average_height(
stcl_quads) stcl_quads)
if self.tcc_type == 'v3': if self.point_gather_mode == 'v3':
self.f_direction = direction_map[:, :, :-1].copy() self.f_direction = direction_map[:, :, :-1].copy()
pos_res = self.fit_and_gather_tcl_points_v3( pos_res = self.fit_and_gather_tcl_points_v3(
min_area_quad, min_area_quad,
...@@ -545,7 +545,7 @@ class PGProcessTrain(object): ...@@ -545,7 +545,7 @@ class PGProcessTrain(object):
continue continue
pos_l, pos_m = pos_res[0], pos_res[1] pos_l, pos_m = pos_res[0], pos_res[1]
elif self.tcc_type == 'v2': elif self.point_gather_mode == 'v2':
pos_l, pos_m = self.fit_and_gather_tcl_points_v2( pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
min_area_quad, min_area_quad,
poly, poly,
......
...@@ -66,8 +66,17 @@ class PGHead(nn.Layer): ...@@ -66,8 +66,17 @@ class PGHead(nn.Layer):
""" """
""" """
def __init__(self, in_channels, tcc_channels=37, **kwargs): def __init__(self,
in_channels,
character_dict_path='ppocr/utils/ic15_dict.txt',
**kwargs):
super(PGHead, self).__init__() super(PGHead, self).__init__()
# get character_length
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
character_length = len(lines) + 1
self.conv_f_score1 = ConvBNLayer( self.conv_f_score1 = ConvBNLayer(
in_channels=in_channels, in_channels=in_channels,
out_channels=64, out_channels=64,
...@@ -178,7 +187,7 @@ class PGHead(nn.Layer): ...@@ -178,7 +187,7 @@ class PGHead(nn.Layer):
name="conv_f_char{}".format(5)) name="conv_f_char{}".format(5))
self.conv3 = nn.Conv2D( self.conv3 = nn.Conv2D(
in_channels=256, in_channels=256,
out_channels=tcc_channels, out_channels=character_length,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1, padding=1,
......
...@@ -31,12 +31,12 @@ class PGPostProcess(object): ...@@ -31,12 +31,12 @@ class PGPostProcess(object):
""" """
def __init__(self, character_dict_path, valid_set, score_thresh, mode, def __init__(self, character_dict_path, valid_set, score_thresh, mode,
tcc_type, **kwargs): point_gather_mode, **kwargs):
self.character_dict_path = character_dict_path self.character_dict_path = character_dict_path
self.valid_set = valid_set self.valid_set = valid_set
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.mode = mode self.mode = mode
self.tcc_type = tcc_type self.point_gather_mode = point_gather_mode
# c++ la-nms is faster, but only support python 3.5 # c++ la-nms is faster, but only support python 3.5
self.is_python35 = False self.is_python35 = False
...@@ -50,7 +50,7 @@ class PGPostProcess(object): ...@@ -50,7 +50,7 @@ class PGPostProcess(object):
self.score_thresh, self.score_thresh,
outs_dict, outs_dict,
shape_list, shape_list,
tcc_type=self.tcc_type) point_gather_mode=self.point_gather_mode)
if self.mode == 'fast': if self.mode == 'fast':
data = post.pg_postprocess_fast() data = post.pg_postprocess_fast()
else: else:
......
...@@ -91,9 +91,9 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): ...@@ -91,9 +91,9 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
def instance_ctc_greedy_decoder(gather_info, def instance_ctc_greedy_decoder(gather_info,
logits_map, logits_map,
pts_num=4, pts_num=4,
tcc_type='v3'): point_gather_mode='v3'):
_, _, C = logits_map.shape _, _, C = logits_map.shape
if tcc_type == 'v3': if point_gather_mode == 'v3':
insert_num = 0 insert_num = 0
gather_info = np.array(gather_info) gather_info = np.array(gather_info)
length = len(gather_info) - 1 length = len(gather_info) - 1
...@@ -130,7 +130,7 @@ def ctc_decoder_for_image(gather_info_list, ...@@ -130,7 +130,7 @@ def ctc_decoder_for_image(gather_info_list,
logits_map, logits_map,
Lexicon_Table, Lexicon_Table,
pts_num=6, pts_num=6,
tcc_type='v3'): point_gather_mode='v3'):
""" """
CTC decoder using multiple processes. CTC decoder using multiple processes.
""" """
...@@ -140,7 +140,7 @@ def ctc_decoder_for_image(gather_info_list, ...@@ -140,7 +140,7 @@ def ctc_decoder_for_image(gather_info_list,
if len(gather_info) < pts_num: if len(gather_info) < pts_num:
continue continue
dst_str, xys_list = instance_ctc_greedy_decoder( dst_str, xys_list = instance_ctc_greedy_decoder(
gather_info, logits_map, pts_num=pts_num, tcc_type='v3') gather_info, logits_map, pts_num=pts_num, point_gather_mode='v3')
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
if len(dst_str_readable) < 2: if len(dst_str_readable) < 2:
continue continue
...@@ -383,7 +383,7 @@ def generate_pivot_list_fast(p_score, ...@@ -383,7 +383,7 @@ def generate_pivot_list_fast(p_score,
f_direction, f_direction,
Lexicon_Table, Lexicon_Table,
score_thresh=0.5, score_thresh=0.5,
tcc_type='v3'): point_gather_mode='v3'):
""" """
return center point and end point of TCL instance; filter with the char maps; return center point and end point of TCL instance; filter with the char maps;
""" """
...@@ -414,7 +414,7 @@ def generate_pivot_list_fast(p_score, ...@@ -414,7 +414,7 @@ def generate_pivot_list_fast(p_score,
all_pos_yxs, all_pos_yxs,
logits_map=p_char_maps, logits_map=p_char_maps,
Lexicon_Table=Lexicon_Table, Lexicon_Table=Lexicon_Table,
tcc_type='v3') point_gather_mode='v3')
return keep_yxs_list, decoded_str return keep_yxs_list, decoded_str
......
...@@ -34,13 +34,13 @@ class PGNet_PostProcess(object): ...@@ -34,13 +34,13 @@ class PGNet_PostProcess(object):
score_thresh, score_thresh,
outs_dict, outs_dict,
shape_list, shape_list,
tcc_type='v3'): point_gather_mode='v3'):
self.Lexicon_Table = get_dict(character_dict_path) self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set self.valid_set = valid_set
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.outs_dict = outs_dict self.outs_dict = outs_dict
self.shape_list = shape_list self.shape_list = shape_list
self.tcc_type = tcc_type self.point_gather_mode = point_gather_mode
def pg_postprocess_fast(self): def pg_postprocess_fast(self):
p_score = self.outs_dict['f_score'] p_score = self.outs_dict['f_score']
...@@ -65,7 +65,7 @@ class PGNet_PostProcess(object): ...@@ -65,7 +65,7 @@ class PGNet_PostProcess(object):
p_direction, p_direction,
self.Lexicon_Table, self.Lexicon_Table,
score_thresh=self.score_thresh, score_thresh=self.score_thresh,
tcc_type=self.tcc_type) point_gather_mode=self.point_gather_mode)
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs, poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
p_border, ratio_w, ratio_h, p_border, ratio_w, ratio_h,
src_w, src_h, self.valid_set) src_w, src_h, self.valid_set)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册