提交 0a276ad4 编写于 作者: L LDOUBLEV

debug

上级 c342b7a0
...@@ -244,7 +244,7 @@ class KieLabelEncode(object): ...@@ -244,7 +244,7 @@ class KieLabelEncode(object):
def pad_text_indices(self, text_inds): def pad_text_indices(self, text_inds):
"""Pad text index to same length.""" """Pad text index to same length."""
max_len = 100 max_len = 300
recoder_len = max([len(text_ind) for text_ind in text_inds]) recoder_len = max([len(text_ind) for text_ind in text_inds])
padded_text_inds = -np.ones((len(text_inds), max_len), np.int32) padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
for idx, text_ind in enumerate(text_inds): for idx, text_ind in enumerate(text_inds):
...@@ -270,7 +270,7 @@ class KieLabelEncode(object): ...@@ -270,7 +270,7 @@ class KieLabelEncode(object):
np.fill_diagonal(edges, -1) np.fill_diagonal(edges, -1)
labels = np.concatenate([labels, edges], -1) labels = np.concatenate([labels, edges], -1)
padded_text_inds, recoder_len = self.pad_text_indices(text_inds) padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
max_num = 100 max_num = 300
temp_bboxes = np.zeros([max_num, 4]) temp_bboxes = np.zeros([max_num, 4])
h, _ = bboxes.shape h, _ = bboxes.shape
temp_bboxes[:h, :h] = bboxes temp_bboxes[:h, :h] = bboxes
...@@ -278,10 +278,10 @@ class KieLabelEncode(object): ...@@ -278,10 +278,10 @@ class KieLabelEncode(object):
temp_relations = np.zeros([max_num, max_num, 5]) temp_relations = np.zeros([max_num, max_num, 5])
temp_relations[:h, :h, :] = relations temp_relations[:h, :h, :] = relations
temp_padded_text_inds = np.zeros([max_num, 100]) temp_padded_text_inds = np.zeros([max_num, max_num])
temp_padded_text_inds[:h, :] = padded_text_inds temp_padded_text_inds[:h, :] = padded_text_inds
temp_labels = np.zeros([max_num, 100]) temp_labels = np.zeros([max_num, max_num])
temp_labels[:h, :h + 1] = labels temp_labels[:h, :h + 1] = labels
tag = np.array([h, recoder_len]) tag = np.array([h, recoder_len])
......
...@@ -301,33 +301,37 @@ class KieResize(object): ...@@ -301,33 +301,37 @@ class KieResize(object):
img = data['image'] img = data['image']
points = data['points'] points = data['points']
src_h, src_w, _ = img.shape src_h, src_w, _ = img.shape
im_resized, scale_factor, [ratio_h, ratio_w] = self.resize_image(img) im_resized, scale_factor, [ratio_h, ratio_w
], [new_h, new_w] = self.resize_image(img)
resize_points = self.resize_boxes(img, points, scale_factor) resize_points = self.resize_boxes(img, points, scale_factor)
data['ori_image'] = img data['ori_image'] = img
data['ori_boxes'] = points data['ori_boxes'] = points
data['points'] = resize_points data['points'] = resize_points
data['image'] = im_resized data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) data['shape'] = np.array([new_h, new_w])
return data return data
def resize_image(self, img): def resize_image(self, img):
norm_img = np.zeros([1024, 512, 3], dtype='float32') norm_img = np.zeros([1024, 1024, 3], dtype='float32')
scale = [512, 1024] scale = [512, 1024]
h, w = img.shape[:2] h, w = img.shape[:2]
max_long_edge = max(scale) max_long_edge = max(scale)
max_short_edge = min(scale) max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w), scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w)) max_short_edge / min(h, w))
new_size = (int(w * float(scale_factor) + 0.5), resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
int(h * float(scale_factor) + 0.5)) scale_factor) + 0.5)
im = cv2.resize(img, new_size) max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(img, (resize_w, resize_h))
new_h, new_w = im.shape[:2] new_h, new_w = im.shape[:2]
w_scale = new_w / w w_scale = new_w / w
h_scale = new_h / h h_scale = new_h / h
scale_factor = np.array( scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32) [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
norm_img[:new_h, :new_w, :] = im norm_img[:new_h, :new_w, :] = im
return norm_img, scale_factor, [h_scale, w_scale] return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
def resize_boxes(self, im, points, scale_factor): def resize_boxes(self, im, points, scale_factor):
points = points * scale_factor points = points * scale_factor
......
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import paddle
__all__ = ['KIEMetric'] __all__ = ['KIEMetric']
...@@ -25,16 +26,19 @@ class KIEMetric(object): ...@@ -25,16 +26,19 @@ class KIEMetric(object):
def __init__(self, main_indicator='hmean', **kwargs): def __init__(self, main_indicator='hmean', **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.reset() self.reset()
self.node = []
self.gt = []
def __call__(self, preds, batch, **kwargs): def __call__(self, preds, batch, **kwargs):
nodes, _ = preds nodes, _ = preds
gts, tag = batch[4].squeeze(0), batch[5].tolist()[0] gts, tag = batch[4].squeeze(0), batch[5].tolist()[0]
gts = gts[:tag[0], :1].reshape([-1]) gts = gts[:tag[0], :1].reshape([-1])
result = self.compute_f1_score(nodes, gts) self.node.append(nodes.numpy())
self.results.append(result) self.gt.append(gts)
# result = self.compute_f1_score(nodes, gts)
# self.results.append(result)
def compute_f1_score(self, preds, gts): def compute_f1_score(self, preds, gts):
preds = preds.numpy()
ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25] ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]
C = preds.shape[1] C = preds.shape[1]
classes = np.array(sorted(set(range(C)) - set(ignores))) classes = np.array(sorted(set(range(C)) - set(ignores)))
...@@ -48,13 +52,19 @@ class KIEMetric(object): ...@@ -48,13 +52,19 @@ class KIEMetric(object):
return f1[classes] return f1[classes]
def combine_results(self, results): def combine_results(self, results):
data = {'hmean': np.mean(results[0])} node = np.concatenate(self.node, 0)
gts = np.concatenate(self.gt, 0)
results = self.compute_f1_score(node, gts)
data = {'hmean': results.mean()}
return data return data
def get_metric(self): def get_metric(self):
metircs = self.combine_results(self.results) metircs = self.combine_results(self.results)
self.reset() self.reset()
return metircs return metircs
def reset(self): def reset(self):
self.results = [] # clear results self.results = [] # clear results
self.node = []
self.gt = []
...@@ -18,6 +18,8 @@ from __future__ import print_function ...@@ -18,6 +18,8 @@ from __future__ import print_function
import paddle import paddle
from paddle import nn from paddle import nn
import numpy as np
import cv2
__all__ = ["Kie_backbone"] __all__ = ["Kie_backbone"]
...@@ -26,11 +28,21 @@ class Encoder(nn.Layer): ...@@ -26,11 +28,21 @@ class Encoder(nn.Layer):
def __init__(self, num_channels, num_filters): def __init__(self, num_channels, num_filters):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.conv1 = nn.Conv2D( self.conv1 = nn.Conv2D(
num_channels, num_filters, kernel_size=3, stride=1, padding=1) num_channels,
num_filters,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm(num_filters, act='relu') self.bn1 = nn.BatchNorm(num_filters, act='relu')
self.conv2 = nn.Conv2D( self.conv2 = nn.Conv2D(
num_filters, num_filters, kernel_size=3, stride=1, padding=1) num_filters,
num_filters,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn2 = nn.BatchNorm(num_filters, act='relu') self.bn2 = nn.BatchNorm(num_filters, act='relu')
self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
...@@ -41,28 +53,45 @@ class Encoder(nn.Layer): ...@@ -41,28 +53,45 @@ class Encoder(nn.Layer):
x = self.conv2(x) x = self.conv2(x)
x = self.bn2(x) x = self.bn2(x)
x_pooled = self.pool(x) x_pooled = self.pool(x)
return x, x_pooled return x, x_pooled
class Decoder(nn.Layer): class Decoder(nn.Layer):
def __init__(self, num_channels, num_filters): def __init__(self, num_channels, num_filters):
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.up = nn.Conv2DTranspose(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=2,
stride=2)
self.conv1 = nn.Conv2D( self.conv1 = nn.Conv2D(
num_channels, num_filters, kernel_size=3, stride=1, padding=1) num_channels,
num_filters,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm(num_filters, act='relu') self.bn1 = nn.BatchNorm(num_filters, act='relu')
self.conv2 = nn.Conv2D( self.conv2 = nn.Conv2D(
num_filters, num_filters, kernel_size=3, stride=1, padding=1) num_filters,
num_filters,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn2 = nn.BatchNorm(num_filters, act='relu') self.bn2 = nn.BatchNorm(num_filters, act='relu')
self.conv0 = nn.Conv2D(
num_channels,
num_filters,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.bn0 = nn.BatchNorm(num_filters, act='relu')
def forward(self, inputs_prev, inputs): def forward(self, inputs_prev, inputs):
x = self.up(inputs) x = self.conv0(inputs)
x = self.bn0(x)
x = paddle.nn.functional.interpolate(
x, scale_factor=2, mode='bilinear', align_corners=False)
x = paddle.concat([inputs_prev, x], axis=1) x = paddle.concat([inputs_prev, x], axis=1)
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
...@@ -80,18 +109,18 @@ class UNet(nn.Layer): ...@@ -80,18 +109,18 @@ class UNet(nn.Layer):
self.down4 = Encoder(num_channels=64, num_filters=128) self.down4 = Encoder(num_channels=64, num_filters=128)
self.down5 = Encoder(num_channels=128, num_filters=256) self.down5 = Encoder(num_channels=128, num_filters=256)
self.up4 = Decoder(256, 128)
self.up3 = Decoder(128, 64)
self.up2 = Decoder(64, 32)
self.up1 = Decoder(32, 16) self.up1 = Decoder(32, 16)
self.up2 = Decoder(64, 32)
self.up3 = Decoder(128, 64)
self.up4 = Decoder(256, 128)
self.out_channels = 16 self.out_channels = 16
def forward(self, inputs): def forward(self, inputs):
x1, x = self.down1(inputs) x1, _ = self.down1(inputs)
x2, x = self.down2(x) _, x2 = self.down2(x1)
x3, x = self.down3(x) _, x3 = self.down3(x2)
x4, x = self.down4(x) _, x4 = self.down4(x3)
x5, x = self.down5(x) _, x5 = self.down5(x4)
x = self.up4(x4, x5) x = self.up4(x4, x5)
x = self.up3(x3, x) x = self.up3(x3, x)
...@@ -117,10 +146,13 @@ class Kie_backbone(nn.Layer): ...@@ -117,10 +146,13 @@ class Kie_backbone(nn.Layer):
rois_num = paddle.to_tensor(rois_num, dtype='int32') rois_num = paddle.to_tensor(rois_num, dtype='int32')
return rois, rois_num return rois, rois_num
def pre_process(self, relations, texts, gt_bboxes, tag): def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
relations, texts, gt_bboxes, tag = relations.numpy(), texts.numpy( img, relations, texts, gt_bboxes, tag, img_size = img.numpy(
), gt_bboxes.numpy(), tag.numpy().tolist() ), relations.numpy(), texts.numpy(), gt_bboxes.numpy(), tag.numpy(
).tolist(), img_size.numpy()
temp_relations, temp_texts, temp_gt_bboxes = [], [], [] temp_relations, temp_texts, temp_gt_bboxes = [], [], []
h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1]))
img = paddle.to_tensor(img[:, :, :h, :w])
batch = len(tag) batch = len(tag)
for i in range(batch): for i in range(batch):
num, recoder_len = tag[i][0], tag[i][1] num, recoder_len = tag[i][0], tag[i][1]
...@@ -133,13 +165,22 @@ class Kie_backbone(nn.Layer): ...@@ -133,13 +165,22 @@ class Kie_backbone(nn.Layer):
temp_gt_bboxes.append( temp_gt_bboxes.append(
paddle.to_tensor( paddle.to_tensor(
gt_bboxes[i, :num, ...], dtype='float32')) gt_bboxes[i, :num, ...], dtype='float32'))
return temp_relations, temp_texts, temp_gt_bboxes return img, temp_relations, temp_texts, temp_gt_bboxes
def forward(self, inputs): def forward(self, inputs):
img, relations, texts, gt_bboxes, tag = inputs[0], inputs[1], inputs[ img, relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[
2], inputs[3], inputs[5] 1], inputs[2], inputs[3], inputs[5], inputs[-1]
relations, texts, gt_bboxes = self.pre_process(relations, texts, img, relations, texts, gt_bboxes = self.pre_process(
gt_bboxes, tag) img, relations, texts, gt_bboxes, tag, img_size)
# for i in range(4):
# img_t = (img[i].numpy().transpose([1, 2, 0]) * 255.0).astype('uint8')
# img_t = img_t.copy()
# gt_bboxes_t = gt_bboxes[i].cpu().numpy()
# box = gt_bboxes_t.astype(np.int32).reshape((-1, 1, 2))
# cv2.polylines(img_t, [box], True, color=(255, 255, 0), thickness=1)
# cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t)
# # cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t * 255.0)
# exit()
x = self.img_feat(img) x = self.img_feat(img)
boxes, rois_num = self.bbox2roi(gt_bboxes) boxes, rois_num = self.bbox2roi(gt_bboxes)
feats = paddle.fluid.layers.roi_align( feats = paddle.fluid.layers.roi_align(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册