提交 1f76f449 编写于 作者: J Jethong

Add PGNet

上级 1a087990
......@@ -20,6 +20,7 @@ Global:
infer_img:
save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt
Architecture:
model_type: det
algorithm: SAST
......
Global:
use_gpu: False
epoch_num: 600
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/pg_r50_vd_tt/
save_epoch_step: 1
# evaluation is run every 5000 iterationss after the 4000th iteration
eval_batch_step: [ 0, 1000 ]
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: False
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img:
save_res_path: ./output/pg_r50_vd_tt/predicts_pg.txt
Architecture:
model_type: e2e
algorithm: PG
Transform:
Backbone:
name: ResNet
layers: 50
Neck:
name: PGFPN
model_name: large
Head:
name: PGHead
model_name: large
Loss:
name: PGLoss
#Optimizer:
# name: Adam
# beta1: 0.9
# beta2: 0.999
# lr:
# name: Cosine
# learning_rate: 0.001
# warmup_epoch: 1
# regularizer:
# name: 'L2'
# factor: 0
Optimizer:
name: RMSProp
lr:
name: Piecewise
learning_rate: 0.001
decay_epochs: [ 40, 80, 120, 160, 200 ]
values: [ 0.001, 0.00033, 0.0001, 0.000033, 0.00001 ]
regularizer:
name: 'L2'
factor: 0.00005
PostProcess:
name: PGPostProcess
score_thresh: 0.8
cover_thresh: 0.1
nms_thresh: 0.2
Metric:
name: E2EMetric
main_indicator: hmean
Train:
dataset:
name: PGDateSet
label_file_list:
ratio_list:
data_format: textnet # textnet/partvgg
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- PGProcessTrain:
batch_size: 14
data_format: icdar
tcl_len: 64
min_crop_size: 24
min_text_size: 4
max_text_size: 512
- KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader:
shuffle: True
drop_last: True
batch_size_per_card: 1
num_workers: 8
Eval:
dataset:
name: PGDateSet
data_dir: ./train_data/
label_file_list:
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- E2ELabelEncode:
label_list: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
- E2EResizeForTest:
valid_set: totaltext
max_side_len: 768
- NormalizeImage:
scale: 1./255.
mean: [ 0.485, 0.456, 0.406 ]
std: [ 0.229, 0.224, 0.225 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
\ No newline at end of file
......@@ -34,6 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDateSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
......@@ -54,7 +55,8 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet']
support_dict = ['SimpleDataSet', 'LMDBDateSet', 'PGDateSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
......
......@@ -28,6 +28,7 @@ from .label_ops import *
from .east_process import *
from .sast_process import *
from .pg_process import *
def transform(data, ops=None):
......
......@@ -34,6 +34,25 @@ class ClsLabelEncode(object):
return data
class E2ELabelEncode(object):
def __init__(self, label_list, **kwargs):
self.label_list = label_list
def __call__(self, data):
text_label_index_list, temp_text = [], []
texts = data['strs']
for text in texts:
text = text.upper()
temp_text = []
for c_ in text:
if c_ in self.label_list:
temp_text.append(self.label_list.index(c_))
temp_text = temp_text + [36] * (50 - len(temp_text))
text_label_index_list.append(temp_text)
data['strs'] = np.array(text_label_index_list)
return data
class DetLabelEncode(object):
def __init__(self, **kwargs):
pass
......
......@@ -223,3 +223,74 @@ class DetResizeForTest(object):
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class E2EResizeForTest(object):
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
"""
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
此差异已折叠。
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
from .imaug import transform, create_operators
import random
class PGDateSet(Dataset):
def __init__(self, config, mode, logger):
super(PGDateSet, self).__init__()
self.logger = logger
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
self.data_format = dataset_config.get('data_format', 'icdar')
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
# self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
self.data_format)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def shuffle_data_random(self):
if self.do_shuffle:
random.shuffle(self.data_lines)
return
def extract_polys(self, poly_txt_path):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys, txt_tags, txts = [], [], []
with open(poly_txt_path) as f:
for line in f.readlines():
poly_str, txt = line.strip().split('\t')
poly = map(float, poly_str.split(','))
text_polys.append(
np.array(
list(poly), dtype=np.float32).reshape(-1, 2))
txts.append(txt)
if txt == '###':
txt_tags.append(True)
else:
txt_tags.append(False)
return np.array(list(map(np.array, text_polys))), \
np.array(txt_tags, dtype=np.bool), txts
def extract_info_textnet(self, im_fn, img_dir=''):
"""
Extract information from line in textnet format.
"""
info_list = im_fn.split('\t')
img_path = ''
for ext in ['.jpg', '.png', '.jpeg', '.JPG']:
if os.path.exists(os.path.join(img_dir, info_list[0] + ext)):
img_path = os.path.join(img_dir, info_list[0] + ext)
break
if img_path == '':
print('Image {0} NOT found in {1}, and it will be ignored.'.format(
info_list[0], img_dir))
nBox = (len(info_list) - 1) // 9
wordBBs, txts, txt_tags = [], [], []
for n in range(0, nBox):
wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
txt = info_list[(n + 1) * 9]
wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
[wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
txts.append(txt)
if txt == '###':
txt_tags.append(True)
else:
txt_tags.append(False)
return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, data_source in enumerate(file_list):
image_files = []
if data_format == 'icdar':
image_files = [
(data_source, x)
for x in os.listdir(os.path.join(data_source, 'rgb'))
if x.split('.')[-1] in ['jpg', 'png', 'jpeg', 'JPG']
]
elif data_format == 'textnet':
with open(data_source) as f:
image_files = [(data_source, x.strip())
for x in f.readlines()]
else:
print("Unrecognized data format...")
exit(-1)
image_files = random.sample(
image_files, round(len(image_files) * ratio_list[idx]))
data_lines.extend(image_files)
return data_lines
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_path, data_line = self.data_lines[file_idx]
try:
if self.data_format == 'icdar':
im_path = os.path.join(data_path, 'rgb', data_line)
poly_path = os.path.join(data_path, 'poly',
data_line.split('.')[0] + '.txt')
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
else:
image_dir = os.path.join(os.path.dirname(data_path), 'image')
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
data_line, image_dir)
data = {
'img_path': im_path,
'polys': text_polys,
'tags': text_tags,
'strs': text_strs
}
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
self.data_idx_order_list[idx], e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
......@@ -29,10 +29,11 @@ def build_loss(config):
# cls loss
from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss'
]
'SRNLoss', 'PGLoss']
config = copy.deepcopy(config)
module_name = config.pop('name')
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
import paddle
import numpy as np
import copy
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
class PGLoss(nn.Layer):
"""
Differentiable Binarization (DB) Loss Function
args:
param (dict): the super paramter for DB Loss
"""
def __init__(self, alpha=5, beta=10, eps=1e-6, **kwargs):
super(PGLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.dice_loss = DiceLoss(eps=eps)
def org_tcl_rois(self, batch_size, pos_lists, pos_masks, label_lists):
"""
"""
pos_lists_, pos_masks_, label_lists_ = [], [], []
img_bs = batch_size
tcl_bs = 64
ngpu = int(batch_size / img_bs)
img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
pos_lists_split, pos_masks_split, label_lists_split = [], [], []
for i in range(ngpu):
pos_lists_split.append([])
pos_masks_split.append([])
label_lists_split.append([])
for i in range(img_ids.shape[0]):
img_id = img_ids[i]
gpu_id = int(img_id / img_bs)
img_id = img_id % img_bs
pos_list = pos_lists[i].copy()
pos_list[:, 0] = img_id
pos_lists_split[gpu_id].append(pos_list)
pos_masks_split[gpu_id].append(pos_masks[i].copy())
label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
# repeat or delete
for i in range(ngpu):
vp_len = len(pos_lists_split[i])
if vp_len <= tcl_bs:
for j in range(0, tcl_bs - vp_len):
pos_list = pos_lists_split[i][j].copy()
pos_lists_split[i].append(pos_list)
pos_mask = pos_masks_split[i][j].copy()
pos_masks_split[i].append(pos_mask)
label_list = copy.deepcopy(label_lists_split[i][j])
label_lists_split[i].append(label_list)
else:
for j in range(0, vp_len - tcl_bs):
c_len = len(pos_lists_split[i])
pop_id = np.random.permutation(c_len)[0]
pos_lists_split[i].pop(pop_id)
pos_masks_split[i].pop(pop_id)
label_lists_split[i].pop(pop_id)
# merge
for i in range(ngpu):
pos_lists_.extend(pos_lists_split[i])
pos_masks_.extend(pos_masks_split[i])
label_lists_.extend(label_lists_split[i])
return pos_lists_, pos_masks_, label_lists_
def pre_process(self, label_list, pos_list, pos_mask):
label_list = label_list.numpy()
b, h, w, c = label_list.shape
pos_list = pos_list.numpy()
pos_mask = pos_mask.numpy()
pos_list_t = []
pos_mask_t = []
label_list_t = []
for i in range(b):
for j in range(30):
if pos_mask[i, j].any():
pos_list_t.append(pos_list[i][j])
pos_mask_t.append(pos_mask[i][j])
label_list_t.append(label_list[i][j])
pos_list, pos_mask, label_list = self.org_tcl_rois(
b, pos_list_t, pos_mask_t, label_list_t)
label = []
tt = [l.tolist() for l in label_list]
for i in range(64):
k = 0
for j in range(50):
if tt[i][j][0] != 36:
k += 1
else:
break
label.append(k)
label = paddle.to_tensor(label)
label = paddle.cast(label, dtype='int64')
pos_list = paddle.to_tensor(pos_list)
pos_mask = paddle.to_tensor(pos_mask)
label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
label_list = paddle.cast(label_list, dtype='int32')
return pos_list, pos_mask, label_list, label
def border_loss(self, f_border, l_border, l_score, l_mask):
l_border_split, l_border_norm = paddle.tensor.split(
l_border, num_or_sections=[4, 1], axis=1)
f_border_split = f_border
b, c, h, w = l_border_norm.shape
l_border_norm_split = paddle.expand(
x=l_border_norm, shape=[b, 4 * c, h, w])
b, c, h, w = l_score.shape
l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
b, c, h, w = l_mask.shape
l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
border_diff = l_border_split - f_border_split
abs_border_diff = paddle.abs(border_diff)
border_sign = abs_border_diff < 1.0
border_sign = paddle.cast(border_sign, dtype='float32')
border_sign.stop_gradient = True
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
(abs_border_diff - 0.5) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
return border_loss
def direction_loss(self, f_direction, l_direction, l_score, l_mask):
l_direction_split, l_direction_norm = paddle.tensor.split(
l_direction, num_or_sections=[2, 1], axis=1)
f_direction_split = f_direction
b, c, h, w = l_direction_norm.shape
l_direction_norm_split = paddle.expand(
x=l_direction_norm, shape=[b, 2 * c, h, w])
b, c, h, w = l_score.shape
l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
b, c, h, w = l_mask.shape
l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
direction_diff = l_direction_split - f_direction_split
abs_direction_diff = paddle.abs(direction_diff)
direction_sign = abs_direction_diff < 1.0
direction_sign = paddle.cast(direction_sign, dtype='float32')
direction_sign.stop_gradient = True
direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
(abs_direction_diff - 0.5) * (1.0 - direction_sign)
direction_out_loss = l_direction_norm_split * direction_in_loss
direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
(paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
return direction_loss
def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
f_char = paddle.transpose(f_char, [0, 2, 3, 1])
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
tcl_pos = paddle.cast(tcl_pos, dtype=int)
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
f_tcl_char = paddle.reshape(f_tcl_char,
[-1, 64, 37]) # len(Lexicon_Table)+1
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
b, c, l = tcl_mask.shape
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
tcl_mask_fg.stop_gradient = True
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
-20.0)
f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
N, B, _ = f_tcl_char_ld.shape
input_lengths = paddle.to_tensor([N] * B, dtype='int64')
cost = paddle.nn.functional.ctc_loss(
log_probs=f_tcl_char_ld,
labels=tcl_label,
input_lengths=input_lengths,
label_lengths=label_t,
blank=36,
reduction='none')
cost = cost.mean()
return cost
def forward(self, predicts, labels):
images, tcl_maps, tcl_label_maps, border_maps \
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
# for all the batch_size
pos_list, pos_mask, label_list, label_t = self.pre_process(
label_list, pos_list, pos_mask)
f_score, f_boder, f_direction, f_char = predicts
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
border_loss = self.border_loss(f_boder, border_maps, tcl_maps,
training_masks)
direction_loss = self.direction_loss(f_direction, direction_maps,
tcl_maps, training_masks)
ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
losses = {
'loss': loss_all,
"score_loss": score_loss,
"border_loss": border_loss,
"direction_loss": direction_loss,
"ctc_loss": ctc_loss
}
return losses
......@@ -26,8 +26,9 @@ def build_metric(config):
from .det_metric import DetMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']
config = copy.deepcopy(config)
module_name = config.pop('name')
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__all__ = ['E2EMetric']
from ppocr.utils.e2e_metric.Deteval import *
class E2EMetric(object):
def __init__(self, main_indicator='f_score_e2e', **kwargs):
self.label_list = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
'''
batch: a list produced by dataloaders.
image: np.ndarray of shape (N, C, H, W).
ratio_list: np.ndarray of shape(N,2)
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
preds: a list of dict produced by post process
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
'''
gt_polyons_batch = batch[2]
temp_gt_strs_batch = batch[3]
ignore_tags_batch = batch[4]
gt_strs_batch = []
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
for temp_list in temp_gt_strs_batch:
t = ""
for index in temp_list:
if index < 36:
t += self.label_list[index]
gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip(
preds, gt_polyons_batch, gt_strs_batch, ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': gt_str,
'ignore': ignore_tag
} for gt_polyon, gt_str, ignore_tag in
zip(gt_polyons, gt_strs, ignore_tags)]
# prepare det
e2e_info_list = [{
'points': det_polyon,
'text': pred_str
} for det_polyon, pred_str in zip(pred['points'], preds['strs'])]
result = get_socre(gt_info_list, e2e_info_list)
self.results.append(result)
def get_metric(self):
"""
return metrics {
'precision': 0,
'recall': 0,
'hmean': 0
}
"""
metircs = combine_results(self.results)
self.reset()
return metircs
def reset(self):
self.results = [] # clear results
......@@ -200,7 +200,8 @@ class DetectionIoUEvaluator(object):
methodPrecision = 0 if numGlobalCareDet == 0 else float(
matchedSum) / numGlobalCareDet
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
methodRecall * methodPrecision / (methodRecall + methodPrecision)
methodRecall * methodPrecision / (
methodRecall + methodPrecision)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics = {
......
......@@ -26,6 +26,9 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
else:
raise NotImplementedError
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ["ResNet"]
class ConvBNLayer(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs):
# if self.is_vd_mode:
# inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=stride,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
# depth = [3, 4, 6, 3]
depth = [3, 4, 6, 3, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512, 1024,
2048] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=7,
stride=2,
act='relu',
name="conv1_1")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
self.out_channels = [3, 64]
# num_filters = [64, 128, 256, 512, 512]
if layers >= 50:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
out = [inputs]
y = self.conv1_1(inputs)
out.append(y)
y = self.pool2d_max(y)
for block in self.stages:
y = block(y)
out.append(y)
return out
......@@ -20,6 +20,7 @@ def build_head(config):
from .det_db_head import DBHead
from .det_east_head import EASTHead
from .det_sast_head import SASTHead
from .e2e_pg_head import PGHead
# rec head
from .rec_ctc_head import CTCHead
......@@ -30,8 +31,8 @@ def build_head(config):
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead'
]
'SRNHead', 'PGHead']
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance",
use_global_stats=False)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class PGHead(nn.Layer):
"""
"""
def __init__(self, in_channels, model_name, **kwargs):
super(PGHead, self).__init__()
self.model_name = model_name
self.conv_f_score1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_score{}".format(1))
self.conv_f_score2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_score{}".format(2))
self.conv_f_score3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_score{}".format(3))
self.conv1 = nn.Conv2D(
in_channels=128,
out_channels=1,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
bias_attr=False)
self.conv_f_boder1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_boder{}".format(1))
self.conv_f_boder2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_boder{}".format(2))
self.conv_f_boder3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_boder{}".format(3))
self.conv2 = nn.Conv2D(
in_channels=128,
out_channels=4,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
bias_attr=False)
self.conv_f_char1 = ConvBNLayer(
in_channels=in_channels,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(1))
self.conv_f_char2 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_char{}".format(2))
self.conv_f_char3 = ConvBNLayer(
in_channels=128,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(3))
self.conv_f_char4 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_char{}".format(4))
self.conv_f_char5 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(5))
self.conv3 = nn.Conv2D(
in_channels=256,
out_channels=6625,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
bias_attr=False)
self.conv_f_direc1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_direc{}".format(1))
self.conv_f_direc2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_direc{}".format(2))
self.conv_f_direc3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_direc{}".format(3))
self.conv4 = nn.Conv2D(
in_channels=128,
out_channels=2,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
bias_attr=False)
def forward(self, x):
f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score)
f_score = self.conv1(f_score)
f_score = F.sigmoid(f_score)
# f_boder
f_boder = self.conv_f_boder1(x)
f_boder = self.conv_f_boder2(f_boder)
f_boder = self.conv_f_boder3(f_boder)
f_boder = self.conv2(f_boder)
f_char = self.conv_f_char1(x)
f_char = self.conv_f_char2(f_char)
f_char = self.conv_f_char3(f_char)
f_char = self.conv_f_char4(f_char)
f_char = self.conv_f_char5(f_char)
f_char = self.conv3(f_char)
f_direction = self.conv_f_direc1(x)
f_direction = self.conv_f_direc2(f_direction)
f_direction = self.conv_f_direc3(f_direction)
f_direction = self.conv4(f_direction)
return f_score, f_boder, f_direction, f_char
......@@ -14,12 +14,14 @@
__all__ = ['build_neck']
def build_neck(config):
from .db_fpn import DBFPN
from .east_fpn import EASTFPN
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
from .pg_fpn import PGFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
use_global_stats=False)
def forward(self, inputs):
# if self.is_vd_mode:
# inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class DeConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
groups=1,
if_act=True,
act=None,
name=None):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.deconv = nn.Conv2DTranspose(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance",
use_global_stats=False)
def forward(self, x):
x = self.deconv(x)
x = self.bn(x)
return x
class FPN_Up_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Up_Fusion, self).__init__()
in_channels = in_channels[::-1]
out_channels = [256, 256, 192, 192, 128]
self.h0_conv = ConvBNLayer(
in_channels[0], out_channels[0], 1, 1, act=None, name='conv_h0')
self.h1_conv = ConvBNLayer(
in_channels[1], out_channels[1], 1, 1, act=None, name='conv_h1')
self.h2_conv = ConvBNLayer(
in_channels[2], out_channels[2], 1, 1, act=None, name='conv_h2')
self.h3_conv = ConvBNLayer(
in_channels[3], out_channels[3], 1, 1, act=None, name='conv_h3')
self.h4_conv = ConvBNLayer(
in_channels[4], out_channels[4], 1, 1, act=None, name='conv_h4')
self.dconv0 = DeConvBNLayer(
in_channels=out_channels[0],
out_channels=out_channels[1],
name="dconv_{}".format(0))
self.dconv1 = DeConvBNLayer(
in_channels=out_channels[1],
out_channels=out_channels[2],
act=None,
name="dconv_{}".format(1))
self.dconv2 = DeConvBNLayer(
in_channels=out_channels[2],
out_channels=out_channels[3],
act=None,
name="dconv_{}".format(2))
self.dconv3 = DeConvBNLayer(
in_channels=out_channels[3],
out_channels=out_channels[4],
act=None,
name="dconv_{}".format(3))
self.conv_g1 = ConvBNLayer(
in_channels=out_channels[1],
out_channels=out_channels[1],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(1))
self.conv_g2 = ConvBNLayer(
in_channels=out_channels[2],
out_channels=out_channels[2],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(2))
self.conv_g3 = ConvBNLayer(
in_channels=out_channels[3],
out_channels=out_channels[3],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(3))
self.conv_g4 = ConvBNLayer(
in_channels=out_channels[4],
out_channels=out_channels[4],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(4))
self.convf = ConvBNLayer(
in_channels=out_channels[4],
out_channels=out_channels[4],
kernel_size=1,
stride=1,
act=None,
name="conv_f{}".format(4))
def _add_relu(self, x1, x2):
x = paddle.add(x=x1, y=x2)
x = F.relu(x)
return x
def forward(self, x):
f = x[2:][::-1]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
h3 = self.h3_conv(f[3])
h4 = self.h4_conv(f[4])
g0 = self.dconv0(h0)
g1 = self.dconv2(self.conv_g2(self._add_relu(g0, h1)))
g2 = self.dconv2(self.conv_g2(self._add_relu(g1, h2)))
g3 = self.dconv3(self.conv_g2(self._add_relu(g2, h3)))
g4 = self.dconv4(self.conv_g2(self._add_relu(g3, h4)))
return g4
class FPN_Down_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Down_Fusion, self).__init__()
out_channels = [32, 64, 128]
self.h0_conv = ConvBNLayer(
in_channels[0], out_channels[0], 3, 1, act=None, name='FPN_d1')
self.h1_conv = ConvBNLayer(
in_channels[1], out_channels[1], 3, 1, act=None, name='FPN_d2')
self.h2_conv = ConvBNLayer(
in_channels[2], out_channels[2], 3, 1, act=None, name='FPN_d3')
self.g0_conv = ConvBNLayer(
out_channels[0], out_channels[1], 3, 2, act=None, name='FPN_d4')
self.g1_conv = nn.Sequential(
ConvBNLayer(
out_channels[1],
out_channels[1],
3,
1,
act='relu',
name='FPN_d5'),
ConvBNLayer(
out_channels[1], out_channels[2], 3, 2, act=None,
name='FPN_d6'))
self.g2_conv = nn.Sequential(
ConvBNLayer(
out_channels[2],
out_channels[2],
3,
1,
act='relu',
name='FPN_d7'),
ConvBNLayer(
out_channels[2], out_channels[2], 1, 1, act=None,
name='FPN_d8'))
def forward(self, x):
f = x[:3]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
g0 = self.g0_conv(h0)
g1 = paddle.add(x=g0, y=h1)
g1 = F.relu(g1)
g1 = self.g1_conv(g1)
g2 = paddle.add(x=g1, y=h2)
g2 = F.relu(g2)
g2 = self.g2_conv(g2)
return g2
class PGFPN(nn.Layer):
def __init__(self, in_channels, with_cab=False, **kwargs):
super(PGFPN, self).__init__()
self.in_channels = in_channels
self.with_cab = with_cab
self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels)
self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels)
self.out_channels = 128
def forward(self, x):
# down fpn
f_down = self.FPN_Down_Fusion(x)
# up fpn
f_up = self.FPN_Up_Fusion(x)
# fusion
f_common = paddle.add(x=f_down, y=f_up)
f_common = F.relu(f_common)
return f_common
......@@ -28,10 +28,11 @@ def build_post_process(config, global_config=None):
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
]
config = copy.deepcopy(config)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import numpy as np
from .locality_aware_nms import nms_locality
from ppocr.utils.e2e_utils.extract_textpoint import *
from ppocr.utils.e2e_utils.ski_thin import *
from ppocr.utils.e2e_utils.visual import *
import paddle
import cv2
import time
class PGPostProcess(object):
"""
The post process for SAST.
"""
def __init__(self,
score_thresh=0.5,
nms_thresh=0.2,
sample_pts_num=2,
shrink_ratio_of_width=0.3,
expand_scale=1.0,
tcl_map_thresh=0.5,
**kwargs):
self.result_path = ""
self.valid_set = 'totaltext'
self.Lexicon_Table = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.sample_pts_num = sample_pts_num
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
p_score, p_border, p_direction, p_char = outs_dict[:4]
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
src_h, src_w, ratio_h, ratio_w = shape_list[0]
if self.valid_set != 'totaltext':
is_curved = False
else:
is_curved = True
instance_yxs_list = generate_pivot_list(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
p_char = np.expand_dims(p_char, axis=0)
p_char = paddle.to_tensor(p_char)
char_seq_idx_set = []
for i in range(len(instance_yxs_list)):
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
featyre_seq = paddle.gather_nd(f_char_map, gather_info_lod)
featyre_seq = np.expand_dims(featyre_seq.numpy(), axis=0)
t = len(featyre_seq[0])
featyre_seq = paddle.to_tensor(featyre_seq)
l = np.array([[t]]).astype(np.int64)
length = paddle.to_tensor(l)
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
input=featyre_seq, blank=36, input_length=length)
seq_pred1 = seq_pred[0].numpy().tolist()[0]
seq_len = seq_pred[1].numpy()[0][0]
temp_t = []
for x in seq_pred1[:seq_len]:
temp_t.append(x)
char_seq_idx_set.append(temp_t)
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
print('the length of tcl point is less than 2, repeat')
yx_center_line.append(yx_center_line[-1])
# expand corresponding offset for total-text.
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# for visualization
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
# ndarry: (x, 2)
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
print('expand along width. {}'.format(detected_poly.shape))
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
print('--> too short, {}'.format(keep_str))
continue
keep_str_list.append(keep_str)
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
data = {
'points': poly_list,
'strs': keep_str_list,
}
# visualization
# if self.save_visualization:
# visualize_e2e_result(im_fn, poly_list, keep_str_list, src_im)
# visualize_point_result(im_fn, all_point_list, all_point_pair_list, src_im)
# save detected boxes
# txt_dir = (result_path[:-1] if result_path.endswith('/') else result_path) + '_txt_anno'
# if not os.path.exists(txt_dir):
# os.makedirs(txt_dir)
# res_file = os.path.join(txt_dir, '{}.txt'.format(im_prefix))
# with open(res_file, 'w') as f:
# for i_box, box in enumerate(poly_list):
# seq_str = keep_str_list[i_box]
# box = np.round(box).astype('int32')
# box_str = ','.join(str(s) for s in (box.flatten().tolist()))
# f.write('{}\t{}\r\n'.format(box_str, seq_str))
return data
......@@ -18,6 +18,7 @@ from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
......@@ -67,11 +68,15 @@ class SASTPostProcess(object):
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
......@@ -81,16 +86,23 @@ class SASTPostProcess(object):
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
......@@ -121,12 +133,10 @@ class SASTPostProcess(object):
"""
compute area of a quad.
"""
edge = [
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
]
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
return np.sum(edge) / 2.
def nms(self, dets):
......@@ -157,7 +167,8 @@ class SASTPostProcess(object):
m = quads.shape[0]
gt_tc = np.mean(quads, axis=1) # (m, 2)
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
(1, m, 1)) # (n, m, 2)
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
......@@ -169,26 +180,47 @@ class SASTPostProcess(object):
"""
Estimate sample points number.
"""
eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
eh = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[1] - quad[2])) / 2.0
ew = (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3])) / 2.0
dense_sample_pts_num = max(2, int(ew))
dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
endpoint=True, dtype=np.float32).astype(np.int32)]
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
dense_xy_center_line = xy_text[np.linspace(
0,
xy_text.shape[0] - 1,
dense_sample_pts_num,
endpoint=True,
dtype=np.float32).astype(np.int32)]
dense_xy_center_line_diff = dense_xy_center_line[
1:] - dense_xy_center_line[:-1]
estimate_arc_len = np.sum(
np.linalg.norm(
dense_xy_center_line_diff, axis=1))
sample_pts_num = max(2, int(estimate_arc_len / eh))
return sample_pts_num
def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
def detect_sast(self,
tcl_map,
tvo_map,
tbo_map,
tco_map,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=0.3,
tcl_map_thresh=0.5,
offset_expand=1.0,
out_strid=4.0):
"""
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
"""
# restore quad
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
tvo_map)
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
dets = self.nms(dets)
if dets.shape[0] == 0:
......@@ -202,7 +234,8 @@ class SASTPostProcess(object):
# instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
instance_count, instance_label_map = self.cluster_by_quads_tco(
tcl_map, tcl_map_thresh, quads, tco_map)
# restore single poly with tcl instance.
poly_list = []
......@@ -214,8 +247,8 @@ class SASTPostProcess(object):
continue
#
len1 = float(np.linalg.norm(quad[0] -quad[1]))
len2 = float(np.linalg.norm(quad[1] -quad[2]))
len1 = float(np.linalg.norm(quad[0] - quad[1]))
len2 = float(np.linalg.norm(quad[1] - quad[2]))
min_len = min(len1, len2)
if min_len < 3:
continue
......@@ -231,9 +264,11 @@ class SASTPostProcess(object):
continue
# sort xy_text
left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
left_center_pt = np.array(
[[(quad[0, 0] + quad[-1, 0]) / 2.0,
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
right_center_pt = np.array(
[[(quad[1, 0] + quad[2, 0]) / 2.0,
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / \
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
......@@ -245,28 +280,40 @@ class SASTPostProcess(object):
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
else:
sample_pts_num = self.sample_pts_num
xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
endpoint=True, dtype=np.float32).astype(np.int32)]
xy_center_line = xy_text[np.linspace(
0,
xy_text.shape[0] - 1,
sample_pts_num,
endpoint=True,
dtype=np.float32).astype(np.int32)]
point_pair_list = []
for x, y in xy_center_line:
# get corresponding offset
offset = tbo_map[y, x, :].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
# original point
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# ndarry: (x, 2), expand poly along width
detected_poly = self.point_pair2poly(point_pair_list)
detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
detected_poly = self.expand_poly_along_width(detected_poly,
shrink_ratio_of_width)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
poly_list.append(detected_poly)
return poly_list
......@@ -285,16 +332,24 @@ class SASTPostProcess(object):
img_num = len(shape_list)
poly_lists = []
for ino in range(img_num):
p_score = score_list[ino].transpose((1,2,0))
p_border = border_list[ino].transpose((1,2,0))
p_tvo = tvo_list[ino].transpose((1,2,0))
p_tco = tco_list[ino].transpose((1,2,0))
p_score = score_list[ino].transpose((1, 2, 0))
p_border = border_list[ino].transpose((1, 2, 0))
p_tvo = tvo_list[ino].transpose((1, 2, 0))
p_tco = tco_list[ino].transpose((1, 2, 0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
poly_list = self.detect_sast(
p_score,
p_tvo,
p_border,
p_tco,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=self.shrink_ratio_of_width,
tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
tcl_map_thresh=self.tcl_map_thresh,
offset_expand=self.expand_scale)
poly_lists.append({'points': np.array(poly_list)})
return poly_lists
此差异已折叠。
import numpy as np
from shapely.geometry import Polygon
#import Polygon
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices
:param gt_x: [1, N] Xs of groundtruth's vertices
:param gt_y: [1, N] Ys of groundtruth's vertices
##############
All the calculation of 'AREA' in this script is handled by:
1) First generating a binary mask with the polygon area filled up with 1's
2) Summing up all the 1's
"""
def area(x, y):
polygon = Polygon(np.stack([x, y], axis=1))
return float(polygon.area)
def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
"""
This helper determine if both polygons are intersecting with each others with an approximation method.
Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax]
"""
det_ymax = np.max(det_y)
det_xmax = np.max(det_x)
det_ymin = np.min(det_y)
det_xmin = np.min(det_x)
gt_ymax = np.max(gt_y)
gt_xmax = np.max(gt_x)
gt_ymin = np.min(gt_y)
gt_xmin = np.min(gt_x)
all_min_ymax = np.minimum(det_ymax, gt_ymax)
all_max_ymin = np.maximum(det_ymin, gt_ymin)
intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
all_min_xmax = np.minimum(det_xmax, gt_xmax)
all_max_xmin = np.maximum(det_xmin, gt_xmin)
intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
return intersect_heights * intersect_widths
def area_of_intersection(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.intersection(p2).area)
def area_of_union(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.union(p2).area)
def iou(det_x, det_y, gt_x, gt_y):
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
def iod(det_x, det_y, gt_x, gt_y):
"""
This helper determine the fraction of intersection area over detection area
"""
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area(det_x, det_y) + 1.0)
此差异已折叠。
此差异已折叠。
"""
Algorithms for computing the skeleton of a binary image
"""
import numpy as np
from scipy import ndimage as ndi
G123_LUT = np.array(
[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0,
0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0
],
dtype=np.bool)
G123P_LUT = np.array(
[
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1,
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
dtype=np.bool)
def thin(image, max_iter=None):
"""
Perform morphological thinning of a binary image.
Parameters
----------
image : binary (M, N) ndarray
The image to be thinned.
max_iter : int, number of iterations, optional
Regardless of the value of this parameter, the thinned image
is returned immediately if an iteration produces no change.
If this parameter is specified it thus sets an upper bound on
the number of iterations performed.
Returns
-------
out : ndarray of bool
Thinned image.
See also
--------
skeletonize, medial_axis
Notes
-----
This algorithm [1]_ works by making multiple passes over the image,
removing pixels matching a set of criteria designed to thin
connected regions while preserving eight-connected components and
2 x 2 squares [2]_. In each of the two sub-iterations the algorithm
correlates the intermediate skeleton image with a neighborhood mask,
then looks up each neighborhood in a lookup table indicating whether
the central pixel should be deleted in that sub-iteration.
References
----------
.. [1] Z. Guo and R. W. Hall, "Parallel thinning with
two-subiteration algorithms," Comm. ACM, vol. 32, no. 3,
pp. 359-373, 1989. :DOI:`10.1145/62065.62074`
.. [2] Lam, L., Seong-Whan Lee, and Ching Y. Suen, "Thinning
Methodologies-A Comprehensive Survey," IEEE Transactions on
Pattern Analysis and Machine Intelligence, Vol 14, No. 9,
p. 879, 1992. :DOI:`10.1109/34.161346`
Examples
--------
>>> square = np.zeros((7, 7), dtype=np.uint8)
>>> square[1:-1, 2:-2] = 1
>>> square[0, 1] = 1
>>> square
array([[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
>>> skel = thin(square)
>>> skel.astype(np.uint8)
array([[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
"""
# convert image to uint8 with values in {0, 1}
skel = np.asanyarray(image, dtype=bool).astype(np.uint8)
# neighborhood mask
mask = np.array([[8, 4, 2], [16, 0, 1], [32, 64, 128]], dtype=np.uint8)
# iterate until convergence, up to the iteration limit
max_iter = max_iter or np.inf
n_iter = 0
n_pts_old, n_pts_new = np.inf, np.sum(skel)
while n_pts_old != n_pts_new and n_iter < max_iter:
n_pts_old = n_pts_new
# perform the two "subiterations" described in the paper
for lut in [G123_LUT, G123P_LUT]:
# correlate image with neighborhood mask
N = ndi.correlate(skel, mask, mode='constant')
# take deletion decision from this subiteration's LUT
D = np.take(lut, N)
# perform deletion
skel[D] = 0
n_pts_new = np.sum(skel) # count points after thinning
n_iter += 1
return skel.astype(np.bool)
此差异已折叠。
此差异已折叠。
......@@ -44,6 +44,7 @@ class ArgsParser(ArgumentParser):
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml'
assert args.config is not None, \
"Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
......@@ -374,7 +375,8 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PG'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册