提交 bb49e1a5 编写于 作者: J Jethong

ADD PGNet_v2

上级 1f76f449
......@@ -37,6 +37,7 @@ class ClsLabelEncode(object):
class E2ELabelEncode(object):
def __init__(self, label_list, **kwargs):
self.label_list = label_list
self.max_len = 50
def __call__(self, data):
text_label_index_list, temp_text = [], []
......@@ -47,7 +48,7 @@ class E2ELabelEncode(object):
for c_ in text:
if c_ in self.label_list:
temp_text.append(self.label_list.index(c_))
temp_text = temp_text + [36] * (50 - len(temp_text))
temp_text = temp_text + [36] * (self.max_len - len(temp_text))
text_label_index_list.append(temp_text)
data['strs'] = np.array(text_label_index_list)
return data
......
......@@ -32,16 +32,6 @@ class E2EMetric(object):
self.reset()
def __call__(self, preds, batch, **kwargs):
'''
batch: a list produced by dataloaders.
image: np.ndarray of shape (N, C, H, W).
ratio_list: np.ndarray of shape(N,2)
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
preds: a list of dict produced by post process
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
'''
gt_polyons_batch = batch[2]
temp_gt_strs_batch = batch[3]
ignore_tags_batch = batch[4]
......@@ -72,13 +62,6 @@ class E2EMetric(object):
self.results.append(result)
def get_metric(self):
"""
return metrics {
'precision': 0,
'recall': 0,
'hmean': 0
}
"""
metircs = combine_results(self.results)
self.reset()
return metircs
......
......@@ -106,172 +106,212 @@ class DeConvBNLayer(nn.Layer):
return x
class FPN_Up_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Up_Fusion, self).__init__()
in_channels = in_channels[::-1]
out_channels = [256, 256, 192, 192, 128]
class PGFPN(nn.Layer):
def __init__(self, in_channels, **kwargs):
super(PGFPN, self).__init__()
num_inputs = [2048, 2048, 1024, 512, 256]
num_outputs = [256, 256, 192, 192, 128]
self.out_channels = 128
# print(in_channels)
self.conv_bn_layer_1 = ConvBNLayer(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=1,
act=None,
name='FPN_d1')
self.conv_bn_layer_2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
act=None,
name='FPN_d2')
self.conv_bn_layer_3 = ConvBNLayer(
in_channels=256,
out_channels=128,
kernel_size=3,
stride=1,
act=None,
name='FPN_d3')
self.conv_bn_layer_4 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=2,
act=None,
name='FPN_d4')
self.conv_bn_layer_5 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name='FPN_d5')
self.conv_bn_layer_6 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=3,
stride=2,
act=None,
name='FPN_d6')
self.conv_bn_layer_7 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
act='relu',
name='FPN_d7')
self.conv_bn_layer_8 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=1,
stride=1,
act=None,
name='FPN_d8')
self.h0_conv = ConvBNLayer(
in_channels[0], out_channels[0], 1, 1, act=None, name='conv_h0')
self.h1_conv = ConvBNLayer(
in_channels[1], out_channels[1], 1, 1, act=None, name='conv_h1')
self.h2_conv = ConvBNLayer(
in_channels[2], out_channels[2], 1, 1, act=None, name='conv_h2')
self.h3_conv = ConvBNLayer(
in_channels[3], out_channels[3], 1, 1, act=None, name='conv_h3')
self.h4_conv = ConvBNLayer(
in_channels[4], out_channels[4], 1, 1, act=None, name='conv_h4')
self.conv_h0 = ConvBNLayer(
in_channels=num_inputs[0],
out_channels=num_outputs[0],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(0))
self.conv_h1 = ConvBNLayer(
in_channels=num_inputs[1],
out_channels=num_outputs[1],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(1))
self.conv_h2 = ConvBNLayer(
in_channels=num_inputs[2],
out_channels=num_outputs[2],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(2))
self.conv_h3 = ConvBNLayer(
in_channels=num_inputs[3],
out_channels=num_outputs[3],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(3))
self.conv_h4 = ConvBNLayer(
in_channels=num_inputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(4))
self.dconv0 = DeConvBNLayer(
in_channels=out_channels[0],
out_channels=out_channels[1],
in_channels=num_outputs[0],
out_channels=num_outputs[0 + 1],
name="dconv_{}".format(0))
self.dconv1 = DeConvBNLayer(
in_channels=out_channels[1],
out_channels=out_channels[2],
in_channels=num_outputs[1],
out_channels=num_outputs[1 + 1],
act=None,
name="dconv_{}".format(1))
self.dconv2 = DeConvBNLayer(
in_channels=out_channels[2],
out_channels=out_channels[3],
in_channels=num_outputs[2],
out_channels=num_outputs[2 + 1],
act=None,
name="dconv_{}".format(2))
self.dconv3 = DeConvBNLayer(
in_channels=out_channels[3],
out_channels=out_channels[4],
in_channels=num_outputs[3],
out_channels=num_outputs[3 + 1],
act=None,
name="dconv_{}".format(3))
self.conv_g1 = ConvBNLayer(
in_channels=out_channels[1],
out_channels=out_channels[1],
in_channels=num_outputs[1],
out_channels=num_outputs[1],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(1))
self.conv_g2 = ConvBNLayer(
in_channels=out_channels[2],
out_channels=out_channels[2],
in_channels=num_outputs[2],
out_channels=num_outputs[2],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(2))
self.conv_g3 = ConvBNLayer(
in_channels=out_channels[3],
out_channels=out_channels[3],
in_channels=num_outputs[3],
out_channels=num_outputs[3],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(3))
self.conv_g4 = ConvBNLayer(
in_channels=out_channels[4],
out_channels=out_channels[4],
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(4))
self.convf = ConvBNLayer(
in_channels=out_channels[4],
out_channels=out_channels[4],
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
name="conv_f{}".format(4))
def _add_relu(self, x1, x2):
x = paddle.add(x=x1, y=x2)
x = F.relu(x)
return x
def forward(self, x):
f = x[2:][::-1]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
h3 = self.h3_conv(f[3])
h4 = self.h4_conv(f[4])
c0, c1, c2, c3, c4, c5, c6 = x
# FPN_Down_Fusion
f = [c0, c1, c2]
g = [None, None, None]
h = [None, None, None]
h[0] = self.conv_bn_layer_1(f[0])
h[1] = self.conv_bn_layer_2(f[1])
h[2] = self.conv_bn_layer_3(f[2])
g0 = self.dconv0(h0)
g[0] = self.conv_bn_layer_4(h[0])
g[1] = paddle.add(g[0], h[1])
g[1] = F.relu(g[1])
g[1] = self.conv_bn_layer_5(g[1])
g[1] = self.conv_bn_layer_6(g[1])
g1 = self.dconv2(self.conv_g2(self._add_relu(g0, h1)))
g2 = self.dconv2(self.conv_g2(self._add_relu(g1, h2)))
g3 = self.dconv3(self.conv_g2(self._add_relu(g2, h3)))
g4 = self.dconv4(self.conv_g2(self._add_relu(g3, h4)))
return g4
g[2] = paddle.add(g[1], h[2])
g[2] = F.relu(g[2])
g[2] = self.conv_bn_layer_7(g[2])
f_down = self.conv_bn_layer_8(g[2])
# FPN UP Fusion
f1 = [c6, c5, c4, c3, c2]
g = [None, None, None, None, None]
h = [None, None, None, None, None]
h[0] = self.conv_h0(f1[0])
h[1] = self.conv_h1(f1[1])
h[2] = self.conv_h2(f1[2])
h[3] = self.conv_h3(f1[3])
h[4] = self.conv_h4(f1[4])
class FPN_Down_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Down_Fusion, self).__init__()
out_channels = [32, 64, 128]
g[0] = self.dconv0(h[0])
g[1] = paddle.add(g[0], h[1])
g[1] = F.relu(g[1])
g[1] = self.conv_g1(g[1])
g[1] = self.dconv1(g[1])
self.h0_conv = ConvBNLayer(
in_channels[0], out_channels[0], 3, 1, act=None, name='FPN_d1')
self.h1_conv = ConvBNLayer(
in_channels[1], out_channels[1], 3, 1, act=None, name='FPN_d2')
self.h2_conv = ConvBNLayer(
in_channels[2], out_channels[2], 3, 1, act=None, name='FPN_d3')
g[2] = paddle.add(g[1], h[2])
g[2] = F.relu(g[2])
g[2] = self.conv_g2(g[2])
g[2] = self.dconv2(g[2])
self.g0_conv = ConvBNLayer(
out_channels[0], out_channels[1], 3, 2, act=None, name='FPN_d4')
g[3] = paddle.add(g[2], h[3])
g[3] = F.relu(g[3])
g[3] = self.conv_g3(g[3])
g[3] = self.dconv3(g[3])
self.g1_conv = nn.Sequential(
ConvBNLayer(
out_channels[1],
out_channels[1],
3,
1,
act='relu',
name='FPN_d5'),
ConvBNLayer(
out_channels[1], out_channels[2], 3, 2, act=None,
name='FPN_d6'))
self.g2_conv = nn.Sequential(
ConvBNLayer(
out_channels[2],
out_channels[2],
3,
1,
act='relu',
name='FPN_d7'),
ConvBNLayer(
out_channels[2], out_channels[2], 1, 1, act=None,
name='FPN_d8'))
def forward(self, x):
f = x[:3]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
g0 = self.g0_conv(h0)
g1 = paddle.add(x=g0, y=h1)
g1 = F.relu(g1)
g1 = self.g1_conv(g1)
g2 = paddle.add(x=g1, y=h2)
g2 = F.relu(g2)
g2 = self.g2_conv(g2)
return g2
class PGFPN(nn.Layer):
def __init__(self, in_channels, with_cab=False, **kwargs):
super(PGFPN, self).__init__()
self.in_channels = in_channels
self.with_cab = with_cab
self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels)
self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels)
self.out_channels = 128
def forward(self, x):
# down fpn
f_down = self.FPN_Down_Fusion(x)
# up fpn
f_up = self.FPN_Up_Fusion(x)
# fusion
f_common = paddle.add(x=f_down, y=f_up)
g[4] = paddle.add(x=g[3], y=h[4])
g[4] = F.relu(g[4])
g[4] = self.conv_g4(g[4])
f_up = self.convf(g[4])
f_common = paddle.add(f_down, f_up)
f_common = F.relu(f_common)
return f_common
from os import listdir
import os, sys
from scipy import io
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm
try: # python2
range = xrange
......@@ -862,16 +871,3 @@ def combine_results(all_data):
'f_score_e2e': f_score_e2e
}
return final
# a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, 1659, 620, 1654, 681, 1631, 680, 1618,
# 681, 1606, 681, 1594, 681, 1584, 682, 1573, 685, 1542, 694]
# gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
# pred_dict = [{'points': np.array(a), 'text': 'ccc'},
# {'points': np.array(a), 'text': 'ccf'}]
# result = []
# for i in range(2):
# result.append(get_socre(gt_dict, pred_dict))
# print(111)
# a = combine_results(result)
# print(a)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from shapely.geometry import Polygon
#import Polygon
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices
......
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
......
"""
Algorithms for computing the skeleton of a binary image
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy import ndimage as ndi
......
import os
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import time
def visualize_e2e_result(im_fn, poly_list, seq_strs, src_im):
"""
"""
result_path = './out'
im_basename = os.path.basename(im_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
vis_det_img = src_im.copy()
valid_set = 'partvgg'
gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train"
text_path = os.path.join(gt_dir, im_prefix + '.txt')
fid = open(text_path, 'r')
lines = [line.strip() for line in fid.readlines()]
for line in lines:
if valid_set == 'partvgg':
tokens = line.strip().split('\t')[0].split(',')
# tokens = line.strip().split(',')
coords = tokens[:]
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, 4, 2)
elif valid_set == 'totaltext':
tokens = line.strip().split('\t')[0].split(',')
coords = tokens[:]
coords_len = len(coords) / 2
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, coords_len, 2)
cv2.polylines(
vis_det_img,
np.array(gt_poly).astype(np.int32),
isClosed=True,
color=(255, 0, 0),
thickness=2)
for detected_poly, recognized_str in zip(poly_list, seq_strs):
cv2.polylines(
vis_det_img,
np.array(detected_poly[np.newaxis, ...]).astype(np.int32),
isClosed=True,
color=(0, 0, 255),
thickness=2)
cv2.putText(
vis_det_img,
recognized_str,
org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1)
if not os.path.exists(result_path):
os.makedirs(result_path)
cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix),
vis_det_img)
def visualization_output(src_image,
f_tcl,
f_chars,
output_dir,
image_prefix=None):
"""
"""
# restore BGR image, CHW -> HWC
im_mean = [0.485, 0.456, 0.406]
im_std = [0.229, 0.224, 0.225]
im_mean = np.array(im_mean).reshape((3, 1, 1))
im_std = np.array(im_std).reshape((3, 1, 1))
src_image *= im_std
src_image += im_mean
src_image = src_image.transpose([1, 2, 0])
src_image = src_image[:, :, ::-1] * 255 # BGR -> RGB
H, W, _ = src_image.shape
file_prefix = image_prefix if image_prefix is not None else str(
int(time.time() * 1000))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# visualization f_tcl
tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg')
vis_tcl_img = src_image.copy()
f_tcl_resized = cv2.resize(f_tcl, dsize=(W, H))
vis_tcl_img[:, :, 1] = f_tcl_resized * 255
cv2.imwrite(tcl_file_name, vis_tcl_img)
# visualization char maps
vis_char_img = src_image.copy()
# CHW -> HWC
char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg')
f_chars = np.argmax(f_chars, axis=2)[:, :, np.newaxis].astype('float32')
f_chars[f_chars < 95] = 1.0
f_chars[f_chars == 95] = 0.0
f_chars_resized = cv2.resize(f_chars, dsize=(W, H))
vis_char_img[:, :, 1] = f_chars_resized * 255
cv2.imwrite(char_file_name, vis_char_img)
def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir,
result_path):
"""
"""
im_basename = os.path.basename(im_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
vis_det_img = src_im.copy()
# draw gt bbox on the image.
text_path = os.path.join(gt_dir, im_prefix + '.txt')
fid = open(text_path, 'r')
lines = [line.strip() for line in fid.readlines()]
for line in lines:
tokens = line.strip().split('\t')
coords = tokens[0].split(',')
coords_len = len(coords)
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, coords_len / 2, 2)
cv2.polylines(
vis_det_img,
np.array(gt_poly).astype(np.int32),
isClosed=True,
color=(255, 255, 255),
thickness=1)
for point, point_pair in zip(point_list, point_pair_list):
cv2.line(
vis_det_img,
tuple(point_pair[0]),
tuple(point_pair[1]), (0, 255, 255),
thickness=1)
cv2.circle(vis_det_img, tuple(point), 2, (0, 0, 255))
cv2.circle(vis_det_img, tuple(point_pair[0]), 2, (255, 0, 0))
cv2.circle(vis_det_img, tuple(point_pair[1]), 2, (0, 255, 0))
if not os.path.exists(result_path):
os.makedirs(result_path)
cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix),
vis_det_img)
def resize_image(im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
......@@ -295,49 +169,3 @@ def norm2(x, axis=None):
def cos(p1, p2):
return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
def generate_direction_info(image_fn,
H,
W,
ratio_h,
ratio_w,
max_length=640,
out_scale=4,
gt_dir=None):
"""
"""
im_basename = os.path.basename(image_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
instance_direction_map = np.zeros(shape=[H // out_scale, W // out_scale, 3])
if gt_dir is None:
gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly'
# get gt label map
text_path = os.path.join(gt_dir, im_prefix + '.txt')
fid = open(text_path, 'r')
lines = [line.strip() for line in fid.readlines()]
for label_idx, line in enumerate(lines, start=1):
coords, txt = line.strip().split('\t')
if txt == '###':
continue
tokens = coords.strip().split(',')
coords = list(map(float, tokens))
poly = np.array(coords).reshape(4, 2) * np.array(
[ratio_w, ratio_h]).reshape(1, 2) / out_scale
mid_idx = poly.shape[0] // 2
direct_vector = (
(poly[mid_idx] + poly[mid_idx - 1]) - (poly[0] + poly[-1])) / 2.0
direct_vector /= len(txt)
# l2_distance = norm2(direct_vector)
# avg_char_distance = l2_distance / len(txt)
avg_char_distance = 1.0
direct_label = (direct_vector[0], direct_vector[1], avg_char_distance)
cv2.fillPoly(instance_direction_map,
poly.round().astype(np.int32)[np.newaxis, :, :],
direct_label)
instance_direction_map = instance_direction_map.transpose([2, 0, 1])
return instance_direction_map[:2, ...]
......@@ -44,7 +44,6 @@ class ArgsParser(ArgumentParser):
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml'
assert args.config is not None, \
"Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册