diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 2804330ec0ac93b651cf9fdf2a6c4c2308e153a9..f2443f65061cfbb14c7fc113e7fa1fad6afdedda 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -244,7 +244,7 @@ class KieLabelEncode(object): def pad_text_indices(self, text_inds): """Pad text index to same length.""" - max_len = 100 + max_len = 300 recoder_len = max([len(text_ind) for text_ind in text_inds]) padded_text_inds = -np.ones((len(text_inds), max_len), np.int32) for idx, text_ind in enumerate(text_inds): @@ -270,7 +270,7 @@ class KieLabelEncode(object): np.fill_diagonal(edges, -1) labels = np.concatenate([labels, edges], -1) padded_text_inds, recoder_len = self.pad_text_indices(text_inds) - max_num = 100 + max_num = 300 temp_bboxes = np.zeros([max_num, 4]) h, _ = bboxes.shape temp_bboxes[:h, :h] = bboxes @@ -278,10 +278,10 @@ class KieLabelEncode(object): temp_relations = np.zeros([max_num, max_num, 5]) temp_relations[:h, :h, :] = relations - temp_padded_text_inds = np.zeros([max_num, 100]) + temp_padded_text_inds = np.zeros([max_num, max_num]) temp_padded_text_inds[:h, :] = padded_text_inds - temp_labels = np.zeros([max_num, 100]) + temp_labels = np.zeros([max_num, max_num]) temp_labels[:h, :h + 1] = labels tag = np.array([h, recoder_len]) diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index ca1a4165d8561ac50a935eae0bfdb930c6b0abbf..2ea498c714d5db6f2efdbdcf511e33cea6235c86 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -301,33 +301,37 @@ class KieResize(object): img = data['image'] points = data['points'] src_h, src_w, _ = img.shape - im_resized, scale_factor, [ratio_h, ratio_w] = self.resize_image(img) + im_resized, scale_factor, [ratio_h, ratio_w + ], [new_h, new_w] = self.resize_image(img) resize_points = self.resize_boxes(img, points, scale_factor) data['ori_image'] = img data['ori_boxes'] = points data['points'] = resize_points data['image'] = im_resized - data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + data['shape'] = np.array([new_h, new_w]) return data def resize_image(self, img): - norm_img = np.zeros([1024, 512, 3], dtype='float32') + norm_img = np.zeros([1024, 1024, 3], dtype='float32') scale = [512, 1024] h, w = img.shape[:2] max_long_edge = max(scale) max_short_edge = min(scale) scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w)) - new_size = (int(w * float(scale_factor) + 0.5), - int(h * float(scale_factor) + 0.5)) - im = cv2.resize(img, new_size) + resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float( + scale_factor) + 0.5) + max_stride = 32 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(img, (resize_w, resize_h)) new_h, new_w = im.shape[:2] w_scale = new_w / w h_scale = new_h / h scale_factor = np.array( [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) norm_img[:new_h, :new_w, :] = im - return norm_img, scale_factor, [h_scale, w_scale] + return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w] def resize_boxes(self, im, points, scale_factor): points = points * scale_factor diff --git a/ppocr/metrics/kie_metric.py b/ppocr/metrics/kie_metric.py index b282491813327289e9c96f02d8b077da1415f28c..761965cfcc25d2a6de30342769d01b36d6212d98 100644 --- a/ppocr/metrics/kie_metric.py +++ b/ppocr/metrics/kie_metric.py @@ -17,6 +17,7 @@ from __future__ import division from __future__ import print_function import numpy as np +import paddle __all__ = ['KIEMetric'] @@ -25,16 +26,19 @@ class KIEMetric(object): def __init__(self, main_indicator='hmean', **kwargs): self.main_indicator = main_indicator self.reset() + self.node = [] + self.gt = [] def __call__(self, preds, batch, **kwargs): nodes, _ = preds gts, tag = batch[4].squeeze(0), batch[5].tolist()[0] gts = gts[:tag[0], :1].reshape([-1]) - result = self.compute_f1_score(nodes, gts) - self.results.append(result) + self.node.append(nodes.numpy()) + self.gt.append(gts) + # result = self.compute_f1_score(nodes, gts) + # self.results.append(result) def compute_f1_score(self, preds, gts): - preds = preds.numpy() ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25] C = preds.shape[1] classes = np.array(sorted(set(range(C)) - set(ignores))) @@ -48,13 +52,19 @@ class KIEMetric(object): return f1[classes] def combine_results(self, results): - data = {'hmean': np.mean(results[0])} + node = np.concatenate(self.node, 0) + gts = np.concatenate(self.gt, 0) + results = self.compute_f1_score(node, gts) + data = {'hmean': results.mean()} return data def get_metric(self): + metircs = self.combine_results(self.results) self.reset() return metircs def reset(self): self.results = [] # clear results + self.node = [] + self.gt = [] diff --git a/ppocr/modeling/backbones/kie_unet_sdmgr.py b/ppocr/modeling/backbones/kie_unet_sdmgr.py index a3c91d903b1d63221509a4f80f69806ce2f7920a..62bae2ea12c43fffe4153878a05b97990b02c502 100644 --- a/ppocr/modeling/backbones/kie_unet_sdmgr.py +++ b/ppocr/modeling/backbones/kie_unet_sdmgr.py @@ -18,6 +18,8 @@ from __future__ import print_function import paddle from paddle import nn +import numpy as np +import cv2 __all__ = ["Kie_backbone"] @@ -26,11 +28,21 @@ class Encoder(nn.Layer): def __init__(self, num_channels, num_filters): super(Encoder, self).__init__() self.conv1 = nn.Conv2D( - num_channels, num_filters, kernel_size=3, stride=1, padding=1) + num_channels, + num_filters, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) self.bn1 = nn.BatchNorm(num_filters, act='relu') self.conv2 = nn.Conv2D( - num_filters, num_filters, kernel_size=3, stride=1, padding=1) + num_filters, + num_filters, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) self.bn2 = nn.BatchNorm(num_filters, act='relu') self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) @@ -41,28 +53,45 @@ class Encoder(nn.Layer): x = self.conv2(x) x = self.bn2(x) x_pooled = self.pool(x) - return x, x_pooled class Decoder(nn.Layer): def __init__(self, num_channels, num_filters): super(Decoder, self).__init__() - self.up = nn.Conv2DTranspose( - in_channels=num_channels, - out_channels=num_filters, - kernel_size=2, - stride=2) + self.conv1 = nn.Conv2D( - num_channels, num_filters, kernel_size=3, stride=1, padding=1) + num_channels, + num_filters, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) self.bn1 = nn.BatchNorm(num_filters, act='relu') self.conv2 = nn.Conv2D( - num_filters, num_filters, kernel_size=3, stride=1, padding=1) + num_filters, + num_filters, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) self.bn2 = nn.BatchNorm(num_filters, act='relu') + self.conv0 = nn.Conv2D( + num_channels, + num_filters, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + self.bn0 = nn.BatchNorm(num_filters, act='relu') + def forward(self, inputs_prev, inputs): - x = self.up(inputs) + x = self.conv0(inputs) + x = self.bn0(x) + x = paddle.nn.functional.interpolate( + x, scale_factor=2, mode='bilinear', align_corners=False) x = paddle.concat([inputs_prev, x], axis=1) x = self.conv1(x) x = self.bn1(x) @@ -80,18 +109,18 @@ class UNet(nn.Layer): self.down4 = Encoder(num_channels=64, num_filters=128) self.down5 = Encoder(num_channels=128, num_filters=256) - self.up4 = Decoder(256, 128) - self.up3 = Decoder(128, 64) - self.up2 = Decoder(64, 32) self.up1 = Decoder(32, 16) + self.up2 = Decoder(64, 32) + self.up3 = Decoder(128, 64) + self.up4 = Decoder(256, 128) self.out_channels = 16 def forward(self, inputs): - x1, x = self.down1(inputs) - x2, x = self.down2(x) - x3, x = self.down3(x) - x4, x = self.down4(x) - x5, x = self.down5(x) + x1, _ = self.down1(inputs) + _, x2 = self.down2(x1) + _, x3 = self.down3(x2) + _, x4 = self.down4(x3) + _, x5 = self.down5(x4) x = self.up4(x4, x5) x = self.up3(x3, x) @@ -117,10 +146,13 @@ class Kie_backbone(nn.Layer): rois_num = paddle.to_tensor(rois_num, dtype='int32') return rois, rois_num - def pre_process(self, relations, texts, gt_bboxes, tag): - relations, texts, gt_bboxes, tag = relations.numpy(), texts.numpy( - ), gt_bboxes.numpy(), tag.numpy().tolist() + def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size): + img, relations, texts, gt_bboxes, tag, img_size = img.numpy( + ), relations.numpy(), texts.numpy(), gt_bboxes.numpy(), tag.numpy( + ).tolist(), img_size.numpy() temp_relations, temp_texts, temp_gt_bboxes = [], [], [] + h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1])) + img = paddle.to_tensor(img[:, :, :h, :w]) batch = len(tag) for i in range(batch): num, recoder_len = tag[i][0], tag[i][1] @@ -133,13 +165,22 @@ class Kie_backbone(nn.Layer): temp_gt_bboxes.append( paddle.to_tensor( gt_bboxes[i, :num, ...], dtype='float32')) - return temp_relations, temp_texts, temp_gt_bboxes + return img, temp_relations, temp_texts, temp_gt_bboxes def forward(self, inputs): - img, relations, texts, gt_bboxes, tag = inputs[0], inputs[1], inputs[ - 2], inputs[3], inputs[5] - relations, texts, gt_bboxes = self.pre_process(relations, texts, - gt_bboxes, tag) + img, relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[ + 1], inputs[2], inputs[3], inputs[5], inputs[-1] + img, relations, texts, gt_bboxes = self.pre_process( + img, relations, texts, gt_bboxes, tag, img_size) + # for i in range(4): + # img_t = (img[i].numpy().transpose([1, 2, 0]) * 255.0).astype('uint8') + # img_t = img_t.copy() + # gt_bboxes_t = gt_bboxes[i].cpu().numpy() + # box = gt_bboxes_t.astype(np.int32).reshape((-1, 1, 2)) + # cv2.polylines(img_t, [box], True, color=(255, 255, 0), thickness=1) + # cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t) + # # cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t * 255.0) + # exit() x = self.img_feat(img) boxes, rois_num = self.bbox2roi(gt_bboxes) feats = paddle.fluid.layers.roi_align(