未验证 提交 d022b269 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #7420 from wangjingyeye/dyg_db

update pgnet
...@@ -13,6 +13,7 @@ Global: ...@@ -13,6 +13,7 @@ Global:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: infer_img:
infer_visual_type: EN # two mode: EN is for english datasets, CN is for chinese datasets
valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
character_dict_path: ppocr/utils/ic15_dict.txt character_dict_path: ppocr/utils/ic15_dict.txt
...@@ -32,6 +33,7 @@ Architecture: ...@@ -32,6 +33,7 @@ Architecture:
name: PGFPN name: PGFPN
Head: Head:
name: PGHead name: PGHead
character_dict_path: ppocr/utils/ic15_dict.txt # the same as Global:character_dict_path
Loss: Loss:
name: PGLoss name: PGLoss
...@@ -45,16 +47,18 @@ Optimizer: ...@@ -45,16 +47,18 @@ Optimizer:
beta1: 0.9 beta1: 0.9
beta2: 0.999 beta2: 0.999
lr: lr:
name: Cosine
learning_rate: 0.001 learning_rate: 0.001
warmup_epoch: 50
regularizer: regularizer:
name: 'L2' name: 'L2'
factor: 0 factor: 0.0001
PostProcess: PostProcess:
name: PGPostProcess name: PGPostProcess
score_thresh: 0.5 score_thresh: 0.5
mode: fast # fast or slow two ways mode: fast # fast or slow two ways
point_gather_mode: align # same as PGProcessTrain: point_gather_mode
Metric: Metric:
name: E2EMetric name: E2EMetric
...@@ -76,9 +80,12 @@ Train: ...@@ -76,9 +80,12 @@ Train:
- E2ELabelEncodeTrain: - E2ELabelEncodeTrain:
- PGProcessTrain: - PGProcessTrain:
batch_size: 14 # same as loader: batch_size_per_card batch_size: 14 # same as loader: batch_size_per_card
use_resize: True
use_random_crop: False
min_crop_size: 24 min_crop_size: 24
min_text_size: 4 min_text_size: 4
max_text_size: 512 max_text_size: 512
point_gather_mode: align # two mode: align and none, align mode is better than none mode
- KeepKeys: - KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader: loader:
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import math import math
import cv2 import cv2
import numpy as np import numpy as np
from skimage.morphology._skeletonize import thin
from ppocr.utils.e2e_utils.extract_textpoint_fast import sort_and_expand_with_direction_v2
__all__ = ['PGProcessTrain'] __all__ = ['PGProcessTrain']
...@@ -26,17 +28,24 @@ class PGProcessTrain(object): ...@@ -26,17 +28,24 @@ class PGProcessTrain(object):
max_text_nums, max_text_nums,
tcl_len, tcl_len,
batch_size=14, batch_size=14,
use_resize=True,
use_random_crop=False,
min_crop_size=24, min_crop_size=24,
min_text_size=4, min_text_size=4,
max_text_size=512, max_text_size=512,
point_gather_mode=None,
**kwargs): **kwargs):
self.tcl_len = tcl_len self.tcl_len = tcl_len
self.max_text_length = max_text_length self.max_text_length = max_text_length
self.max_text_nums = max_text_nums self.max_text_nums = max_text_nums
self.batch_size = batch_size self.batch_size = batch_size
self.min_crop_size = min_crop_size if use_random_crop is True:
self.min_crop_size = min_crop_size
self.use_random_crop = use_random_crop
self.min_text_size = min_text_size self.min_text_size = min_text_size
self.max_text_size = max_text_size self.max_text_size = max_text_size
self.use_resize = use_resize
self.point_gather_mode = point_gather_mode
self.Lexicon_Table = self.get_dict(character_dict_path) self.Lexicon_Table = self.get_dict(character_dict_path)
self.pad_num = len(self.Lexicon_Table) self.pad_num = len(self.Lexicon_Table)
self.img_id = 0 self.img_id = 0
...@@ -282,6 +291,95 @@ class PGProcessTrain(object): ...@@ -282,6 +291,95 @@ class PGProcessTrain(object):
pos_m[:keep] = 1.0 pos_m[:keep] = 1.0
return pos_l, pos_m return pos_l, pos_m
def fit_and_gather_tcl_points_v3(self,
min_area_quad,
poly,
max_h,
max_w,
fixed_point_num=64,
img_id=0,
reference_height=3):
"""
Find the center point of poly as key_points, then fit and gather.
"""
det_mask = np.zeros((int(max_h / self.ds_ratio),
int(max_w / self.ds_ratio))).astype(np.float32)
# score_big_map
cv2.fillPoly(det_mask,
np.round(poly / self.ds_ratio).astype(np.int32), 1.0)
det_mask = cv2.resize(
det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio)
det_mask = np.array(det_mask > 1e-3, dtype='float32')
f_direction = self.f_direction
skeleton_map = thin(det_mask.astype(np.uint8))
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
ys, xs = np.where(instance_label_map == 1)
pos_list = list(zip(ys, xs))
if len(pos_list) < 3:
return None
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, det_mask)
pos_list_sorted = np.array(pos_list_sorted)
length = len(pos_list_sorted) - 1
insert_num = 0
for index in range(length):
stride_y = np.abs(pos_list_sorted[index + insert_num][0] -
pos_list_sorted[index + 1 + insert_num][0])
stride_x = np.abs(pos_list_sorted[index + insert_num][1] -
pos_list_sorted[index + 1 + insert_num][1])
max_points = int(max(stride_x, stride_y))
stride = (pos_list_sorted[index + insert_num] -
pos_list_sorted[index + 1 + insert_num]) / (max_points)
insert_num_temp = max_points - 1
for i in range(int(insert_num_temp)):
insert_value = pos_list_sorted[index + insert_num] - (i + 1
) * stride
insert_index = index + i + 1 + insert_num
pos_list_sorted = np.insert(
pos_list_sorted, insert_index, insert_value, axis=0)
insert_num += insert_num_temp
pos_info = np.array(pos_list_sorted).reshape(-1, 2).astype(
np.float32) # xy-> yx
point_num = len(pos_info)
if point_num > fixed_point_num:
keep_ids = [
int((point_num * 1.0 / fixed_point_num) * x)
for x in range(fixed_point_num)
]
pos_info = pos_info[keep_ids, :]
keep = int(min(len(pos_info), fixed_point_num))
reference_width = (np.abs(poly[0, 0, 0] - poly[-1, 1, 0]) +
np.abs(poly[0, 3, 0] - poly[-1, 2, 0])) // 2
if np.random.rand() < 1:
dh = (np.random.rand(keep) - 0.5) * reference_height
offset = np.random.rand() - 0.5
dw = np.array([[0, offset * reference_width * 0.2]])
random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape(
[keep, 1])
random_float_w = dw.repeat(keep, axis=0)
pos_info += random_float_h
pos_info += random_float_w
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
# padding to fixed length
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
pos_m[:keep] = 1.0
return pos_l, pos_m
def generate_direction_map(self, poly_quads, n_char, direction_map): def generate_direction_map(self, poly_quads, n_char, direction_map):
""" """
""" """
...@@ -334,6 +432,7 @@ class PGProcessTrain(object): ...@@ -334,6 +432,7 @@ class PGProcessTrain(object):
""" """
Generate polygon. Generate polygon.
""" """
self.ds_ratio = ds_ratio
score_map_big = np.zeros( score_map_big = np.zeros(
( (
h, h,
...@@ -384,7 +483,6 @@ class PGProcessTrain(object): ...@@ -384,7 +483,6 @@ class PGProcessTrain(object):
text_label = text_strs[poly_idx] text_label = text_strs[poly_idx]
text_label = self.prepare_text_label(text_label, text_label = self.prepare_text_label(text_label,
self.Lexicon_Table) self.Lexicon_Table)
text_label_index_list = [[self.Lexicon_Table.index(c_)] text_label_index_list = [[self.Lexicon_Table.index(c_)]
for c_ in text_label for c_ in text_label
if c_ in self.Lexicon_Table] if c_ in self.Lexicon_Table]
...@@ -432,14 +530,30 @@ class PGProcessTrain(object): ...@@ -432,14 +530,30 @@ class PGProcessTrain(object):
# pos info # pos info
average_shrink_height = self.calculate_average_height( average_shrink_height = self.calculate_average_height(
stcl_quads) stcl_quads)
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
min_area_quad, if self.point_gather_mode == 'align':
poly, self.f_direction = direction_map[:, :, :-1].copy()
max_h=h, pos_res = self.fit_and_gather_tcl_points_v3(
max_w=w, min_area_quad,
fixed_point_num=64, stcl_quads,
img_id=self.img_id, max_h=h,
reference_height=average_shrink_height) max_w=w,
fixed_point_num=64,
img_id=self.img_id,
reference_height=average_shrink_height)
if pos_res is None:
continue
pos_l, pos_m = pos_res[0], pos_res[1]
else:
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
min_area_quad,
poly,
max_h=h,
max_w=w,
fixed_point_num=64,
img_id=self.img_id,
reference_height=average_shrink_height)
label_l = text_label_index_list label_l = text_label_index_list
if len(text_label_index_list) < 2: if len(text_label_index_list) < 2:
...@@ -770,27 +884,41 @@ class PGProcessTrain(object): ...@@ -770,27 +884,41 @@ class PGProcessTrain(object):
text_polys[:, :, 0] *= asp_wx text_polys[:, :, 0] *= asp_wx
text_polys[:, :, 1] *= asp_hy text_polys[:, :, 1] *= asp_hy
h, w, _ = im.shape if self.use_resize is True:
if max(h, w) > 2048: ori_h, ori_w, _ = im.shape
rd_scale = 2048.0 / max(h, w) if max(ori_h, ori_w) < 200:
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) ratio = 200 / max(ori_h, ori_w)
text_polys *= rd_scale im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
h, w, _ = im.shape text_polys[:, :, 0] *= ratio
if min(h, w) < 16: text_polys[:, :, 1] *= ratio
return None
if max(ori_h, ori_w) > 512:
# no background ratio = 512 / max(ori_h, ori_w)
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area( im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
im, text_polys[:, :, 0] *= ratio
text_polys, text_polys[:, :, 1] *= ratio
text_tags, elif self.use_random_crop is True:
hv_tags, h, w, _ = im.shape
text_strs, if max(h, w) > 2048:
crop_background=False) rd_scale = 2048.0 / max(h, w)
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
text_polys *= rd_scale
h, w, _ = im.shape
if min(h, w) < 16:
return None
# no background
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
im,
text_polys,
text_tags,
hv_tags,
text_strs,
crop_background=False)
if text_polys.shape[0] == 0: if text_polys.shape[0] == 0:
return None return None
# # continue for all ignore case # continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size: if np.sum((text_tags * 1.0)) >= text_tags.size:
return None return None
new_h, new_w, _ = im.shape new_h, new_w, _ = im.shape
......
...@@ -89,12 +89,13 @@ class PGLoss(nn.Layer): ...@@ -89,12 +89,13 @@ class PGLoss(nn.Layer):
tcl_pos = paddle.reshape(tcl_pos, [-1, 3]) tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
tcl_pos = paddle.cast(tcl_pos, dtype=int) tcl_pos = paddle.cast(tcl_pos, dtype=int)
f_tcl_char = paddle.gather_nd(f_char, tcl_pos) f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
f_tcl_char = paddle.reshape(f_tcl_char, f_tcl_char = paddle.reshape(
[-1, 64, 37]) # len(Lexicon_Table)+1 f_tcl_char, [-1, 64, self.pad_num + 1]) # len(Lexicon_Table)+1
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2) f_tcl_char_fg, f_tcl_char_bg = paddle.split(
f_tcl_char, [self.pad_num, 1], axis=2)
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0 f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
b, c, l = tcl_mask.shape b, c, l = tcl_mask.shape
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l]) tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, self.pad_num * l])
tcl_mask_fg.stop_gradient = True tcl_mask_fg.stop_gradient = True
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * ( f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
-20.0) -20.0)
......
...@@ -66,8 +66,17 @@ class PGHead(nn.Layer): ...@@ -66,8 +66,17 @@ class PGHead(nn.Layer):
""" """
""" """
def __init__(self, in_channels, **kwargs): def __init__(self,
in_channels,
character_dict_path='ppocr/utils/ic15_dict.txt',
**kwargs):
super(PGHead, self).__init__() super(PGHead, self).__init__()
# get character_length
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
character_length = len(lines) + 1
self.conv_f_score1 = ConvBNLayer( self.conv_f_score1 = ConvBNLayer(
in_channels=in_channels, in_channels=in_channels,
out_channels=64, out_channels=64,
...@@ -178,7 +187,7 @@ class PGHead(nn.Layer): ...@@ -178,7 +187,7 @@ class PGHead(nn.Layer):
name="conv_f_char{}".format(5)) name="conv_f_char{}".format(5))
self.conv3 = nn.Conv2D( self.conv3 = nn.Conv2D(
in_channels=256, in_channels=256,
out_channels=37, out_channels=character_length,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1, padding=1,
......
...@@ -30,12 +30,18 @@ class PGPostProcess(object): ...@@ -30,12 +30,18 @@ class PGPostProcess(object):
The post process for PGNet. The post process for PGNet.
""" """
def __init__(self, character_dict_path, valid_set, score_thresh, mode, def __init__(self,
character_dict_path,
valid_set,
score_thresh,
mode,
point_gather_mode=None,
**kwargs): **kwargs):
self.character_dict_path = character_dict_path self.character_dict_path = character_dict_path
self.valid_set = valid_set self.valid_set = valid_set
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.mode = mode self.mode = mode
self.point_gather_mode = point_gather_mode
# c++ la-nms is faster, but only support python 3.5 # c++ la-nms is faster, but only support python 3.5
self.is_python35 = False self.is_python35 = False
...@@ -43,8 +49,13 @@ class PGPostProcess(object): ...@@ -43,8 +49,13 @@ class PGPostProcess(object):
self.is_python35 = True self.is_python35 = True
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
post = PGNet_PostProcess(self.character_dict_path, self.valid_set, post = PGNet_PostProcess(
self.score_thresh, outs_dict, shape_list) self.character_dict_path,
self.valid_set,
self.score_thresh,
outs_dict,
shape_list,
point_gather_mode=self.point_gather_mode)
if self.mode == 'fast': if self.mode == 'fast':
data = post.pg_postprocess_fast() data = post.pg_postprocess_fast()
else: else:
......
...@@ -88,8 +88,35 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): ...@@ -88,8 +88,35 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
return dst_str, keep_idx_list return dst_str, keep_idx_list
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): def instance_ctc_greedy_decoder(gather_info,
logits_map,
pts_num=4,
point_gather_mode=None):
_, _, C = logits_map.shape _, _, C = logits_map.shape
if point_gather_mode == 'align':
insert_num = 0
gather_info = np.array(gather_info)
length = len(gather_info) - 1
for index in range(length):
stride_y = np.abs(gather_info[index + insert_num][0] - gather_info[
index + 1 + insert_num][0])
stride_x = np.abs(gather_info[index + insert_num][1] - gather_info[
index + 1 + insert_num][1])
max_points = int(max(stride_x, stride_y))
stride = (gather_info[index + insert_num] -
gather_info[index + 1 + insert_num]) / (max_points)
insert_num_temp = max_points - 1
for i in range(int(insert_num_temp)):
insert_value = gather_info[index + insert_num] - (i + 1
) * stride
insert_index = index + i + 1 + insert_num
gather_info = np.insert(
gather_info, insert_index, insert_value, axis=0)
insert_num += insert_num_temp
gather_info = gather_info.tolist()
else:
pass
ys, xs = zip(*gather_info) ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)] logits_seq = logits_map[list(ys), list(xs)]
probs_seq = logits_seq probs_seq = logits_seq
...@@ -104,7 +131,8 @@ def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): ...@@ -104,7 +131,8 @@ def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
def ctc_decoder_for_image(gather_info_list, def ctc_decoder_for_image(gather_info_list,
logits_map, logits_map,
Lexicon_Table, Lexicon_Table,
pts_num=6): pts_num=6,
point_gather_mode=None):
""" """
CTC decoder using multiple processes. CTC decoder using multiple processes.
""" """
...@@ -114,7 +142,10 @@ def ctc_decoder_for_image(gather_info_list, ...@@ -114,7 +142,10 @@ def ctc_decoder_for_image(gather_info_list,
if len(gather_info) < pts_num: if len(gather_info) < pts_num:
continue continue
dst_str, xys_list = instance_ctc_greedy_decoder( dst_str, xys_list = instance_ctc_greedy_decoder(
gather_info, logits_map, pts_num=pts_num) gather_info,
logits_map,
pts_num=pts_num,
point_gather_mode=point_gather_mode)
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
if len(dst_str_readable) < 2: if len(dst_str_readable) < 2:
continue continue
...@@ -356,7 +387,8 @@ def generate_pivot_list_fast(p_score, ...@@ -356,7 +387,8 @@ def generate_pivot_list_fast(p_score,
p_char_maps, p_char_maps,
f_direction, f_direction,
Lexicon_Table, Lexicon_Table,
score_thresh=0.5): score_thresh=0.5,
point_gather_mode=None):
""" """
return center point and end point of TCL instance; filter with the char maps; return center point and end point of TCL instance; filter with the char maps;
""" """
...@@ -384,7 +416,10 @@ def generate_pivot_list_fast(p_score, ...@@ -384,7 +416,10 @@ def generate_pivot_list_fast(p_score,
p_char_maps = p_char_maps.transpose([1, 2, 0]) p_char_maps = p_char_maps.transpose([1, 2, 0])
decoded_str, keep_yxs_list = ctc_decoder_for_image( decoded_str, keep_yxs_list = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table) all_pos_yxs,
logits_map=p_char_maps,
Lexicon_Table=Lexicon_Table,
point_gather_mode=point_gather_mode)
return keep_yxs_list, decoded_str return keep_yxs_list, decoded_str
......
...@@ -28,13 +28,19 @@ from extract_textpoint_fast import generate_pivot_list_fast, restore_poly ...@@ -28,13 +28,19 @@ from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
class PGNet_PostProcess(object): class PGNet_PostProcess(object):
# two different post-process # two different post-process
def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict, def __init__(self,
shape_list): character_dict_path,
valid_set,
score_thresh,
outs_dict,
shape_list,
point_gather_mode=None):
self.Lexicon_Table = get_dict(character_dict_path) self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set self.valid_set = valid_set
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.outs_dict = outs_dict self.outs_dict = outs_dict
self.shape_list = shape_list self.shape_list = shape_list
self.point_gather_mode = point_gather_mode
def pg_postprocess_fast(self): def pg_postprocess_fast(self):
p_score = self.outs_dict['f_score'] p_score = self.outs_dict['f_score']
...@@ -58,7 +64,8 @@ class PGNet_PostProcess(object): ...@@ -58,7 +64,8 @@ class PGNet_PostProcess(object):
p_char, p_char,
p_direction, p_direction,
self.Lexicon_Table, self.Lexicon_Table,
score_thresh=self.score_thresh) score_thresh=self.score_thresh,
point_gather_mode=self.point_gather_mode)
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs, poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
p_border, ratio_w, ratio_h, p_border, ratio_w, ratio_h,
src_w, src_h, self.valid_set) src_w, src_h, self.valid_set)
......
...@@ -37,6 +37,46 @@ from ppocr.postprocess import build_post_process ...@@ -37,6 +37,46 @@ from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
from PIL import Image, ImageDraw, ImageFont
import math
def draw_e2e_res_for_chinese(image,
boxes,
txts,
config,
img_name,
font_path="./doc/simfang.ttf"):
h, w = image.height, image.width
img_left = image.copy()
img_right = Image.new('RGB', (w, h), (255, 255, 255))
import random
random.seed(0)
draw_left = ImageDraw.Draw(img_left)
draw_right = ImageDraw.Draw(img_right)
for idx, (box, txt) in enumerate(zip(boxes, txts)):
box = np.array(box)
box = [tuple(x) for x in box]
color = (random.randint(0, 255), random.randint(0, 255),
random.randint(0, 255))
draw_left.polygon(box, fill=color)
draw_right.polygon(box, outline=color)
font = ImageFont.truetype(font_path, 15, encoding="utf-8")
draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5)
img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(img_right, (w, 0, w * 2, h))
save_e2e_path = os.path.dirname(config['Global'][
'save_res_path']) + "/e2e_results/"
if not os.path.exists(save_e2e_path):
os.makedirs(save_e2e_path)
save_path = os.path.join(save_e2e_path, os.path.basename(img_name))
cv2.imwrite(save_path, np.array(img_show)[:, :, ::-1])
logger.info("The e2e Image saved in {}".format(save_path))
def draw_e2e_res(dt_boxes, strs, config, img, img_name): def draw_e2e_res(dt_boxes, strs, config, img, img_name):
...@@ -113,7 +153,19 @@ def main(): ...@@ -113,7 +153,19 @@ def main():
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n" otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode()) fout.write(otstr.encode())
src_img = cv2.imread(file) src_img = cv2.imread(file)
draw_e2e_res(points, strs, config, src_img, file) if global_config['infer_visual_type'] == 'EN':
draw_e2e_res(points, strs, config, src_img, file)
elif global_config['infer_visual_type'] == 'CN':
src_img = Image.fromarray(
cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB))
draw_e2e_res_for_chinese(
src_img,
points,
strs,
config,
file,
font_path="./doc/fonts/simfang.ttf")
logger.info("success!") logger.info("success!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册