diff --git a/configs/kie/kie_unet_sdmgr.yml b/configs/kie/kie_unet_sdmgr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..55d07be26c11b1542c78a71bd9e79d4509637652
--- /dev/null
+++ b/configs/kie/kie_unet_sdmgr.yml
@@ -0,0 +1,111 @@
+Global:
+  use_gpu: True
+  epoch_num: 300
+  log_smooth_window: 20
+  print_batch_step: 50
+  save_model_dir: ./output/kie_5/
+  save_epoch_step: 50
+  # evaluation is run every 80 iterations, starting from iteration 0
+  eval_batch_step: [ 0, 80 ]
+  # 1. If pretrained_model is saved in static mode, such as classification pretrained model
+  #    from static branch, load_static_weights must be set as True.
+  # 2. If you want to finetune the pretrained models we provide in the docs,
+  #    you should set load_static_weights as False.
+  load_static_weights: False
+  cal_metric_during_train: False
+  pretrained_model: ./output/kie_4/best_accuracy
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  class_path: ./train_data/wildreceipt/class_list.txt
+  infer_img: ./train_data/wildreceipt/1.txt
+  save_res_path: ./output/sdmgr_kie/predicts_kie.txt
+  img_scale: [ 1024, 512 ]
+
+Architecture:
+  model_type: kie
+  algorithm: SDMGR
+  Transform:
+  Backbone:
+    name: Kie_backbone
+  Head:
+    name: SDMGRHead
+
+Loss:
+  name: SDMGRLoss
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Piecewise
+    learning_rate: 0.001
+    decay_epochs: [ 60, 80, 100]
+    values: [ 0.001, 0.0001, 0.00001]
+    warmup_epoch: 2
+  regularizer:
+    name: 'L2'
+    factor: 0.00005
+
+PostProcess:
+  name: None
+
+Metric:
+  name: KIEMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/wildreceipt/
+    label_file_list: [ './train_data/wildreceipt/wildreceipt_train.txt' ]
+    ratio_list: [ 1.0 ]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: RGB
+          channel_first: False
+      - NormalizeImage:
+          scale: 1
+          mean: [ 123.675, 116.28, 103.53 ]
+          std: [ 58.395, 57.12, 57.375 ]
+          order: 'hwc'
+      - KieLabelEncode: # encode texts and box labels into indices
+          character_dict_path: ./train_data/wildreceipt/dict.txt
+      - KieResize:
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'image', 'relations', 'texts', 'points', 'labels', 'tag', 'shape'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 4
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/wildreceipt
+    label_file_list:
+      - ./train_data/wildreceipt/wildreceipt_test.txt
+      # - /paddle/data/PaddleOCR/train_data/wildreceipt/1.txt
+    transforms:
+      - DecodeImage: # load image
+          img_mode: RGB
+          channel_first: False
+      - KieLabelEncode: # encode texts and box labels into indices
+          character_dict_path: ./train_data/wildreceipt/dict.txt
+      - KieResize:
+      - NormalizeImage:
+          scale: 1
+          mean: [ 123.675, 116.28, 103.53 ]
+          std: [ 58.395, 57.12, 57.375 ]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'image', 'relations', 'texts', 'points', 'labels', 'tag', 'ori_image', 'ori_boxes', 'shape']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1 # must be 1
+    num_workers: 4
\ No newline at end of file
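This config is the single source of truth for the new pipeline: `Architecture` wires up `Kie_backbone` and `SDMGRHead`, the loss and metric are `SDMGRLoss` and `KIEMetric`, and `eval_batch_step: [0, 80]` means evaluation starts at iteration 0 and repeats every 80 iterations. A minimal sketch of how the file can be inspected, not part of the patch, assuming PyYAML is installed and the repo root is the working directory:

```python
# Load the SDMGR config and print the evaluation cadence (sketch only).
import yaml

with open("configs/kie/kie_unet_sdmgr.yml") as f:
    cfg = yaml.safe_load(f)

start_iter, interval = cfg["Global"]["eval_batch_step"]
print(f"eval starts at iter {start_iter}, then runs every {interval} iters")
print("algorithm:", cfg["Architecture"]["algorithm"],
      "backbone:", cfg["Architecture"]["Backbone"]["name"])
```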
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index f0c16618c0dd0b0f0bcc6a06d6b142a59d58e725..494392029fbcca1df336225e66df3d6aca3ad1f1 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -21,6 +21,7 @@ PaddleOCR开源的文本检测算法列表:
 - [x] EAST([paper](https://arxiv.org/abs/1704.03155))[1]
 - [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4]
 - [x] PSENet([paper](https://arxiv.org/abs/1903.12473v2))
+- [x] SDMGR([paper](https://arxiv.org/pdf/2103.14470.pdf))
 
 在ICDAR2015文本检测公开数据集上,算法效果如下:
 |模型|骨干网络|precision|recall|Hmean|下载链接|
@@ -32,6 +33,7 @@ PaddleOCR开源的文本检测算法列表:
 |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
 |PSE|ResNet50_vd|85.81%|79.53%|82.55%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
 |PSE|MobileNetV3|82.20%|70.48%|75.89%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
+|SDMGR|VGG16|-|-|87.11%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar)|
 
 在Total-text文本检测公开数据集上,算法效果如下:
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 0a4fad621a9038e71a9d43eb4e12f78e7e92d73d..344f596d9d7c435ec98074de6d8cc08a420d27e8 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -19,6 +19,7 @@ from __future__ import unicode_literals
 
 import numpy as np
 import string
+from shapely.geometry import LineString, Point, Polygon
 import json
 
 from ppocr.utils.logging import get_logger
@@ -286,6 +287,167 @@ class E2ELabelEncodeTrain(object):
         return data
 
 
+class KieLabelEncode(object):
+    def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
+        super(KieLabelEncode, self).__init__()
+        self.dict = dict({'': 0})
+        with open(character_dict_path, 'r') as fr:
+            idx = 1
+            for line in fr:
+                char = line.strip()
+                self.dict[char] = idx
+                idx += 1
+        self.norm = norm
+        self.directed = directed
+
+    def compute_relation(self, boxes):
+        """Compute relation between every two boxes."""
+        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
+        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
+        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
+        dxs = (x1s[:, 0][None] - x1s) / self.norm
+        dys = (y1s[:, 0][None] - y1s) / self.norm
+        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
+        whs = ws / hs + np.zeros_like(xhhs)
+        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
+        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
+        return relations, bboxes
+
+    def pad_text_indices(self, text_inds):
+        """Pad text indices to the same length."""
+        max_len = 300
+        recoder_len = max([len(text_ind) for text_ind in text_inds])
+        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
+        for idx, text_ind in enumerate(text_inds):
+            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
+        return padded_text_inds, recoder_len
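`compute_relation` consumes boxes whose rows are flattened 4-point quadrilaterals and reads only the top-left corner (columns 0:2) and the bottom-right corner (columns 4:6), producing a 5-d geometric feature for every ordered box pair. A standalone numpy sketch of the same arithmetic, with made-up boxes:

```python
import numpy as np

# Two sorted quads, flattened as [x1, y1, x2, y2, x3, y3, x4, y4];
# only the top-left (cols 0:2) and bottom-right (cols 4:6) corners are used.
boxes = np.array([
    [10, 10, 110, 10, 110, 40, 10, 40],    # box 0: 100 x 30
    [10, 60, 210, 60, 210, 100, 10, 100],  # box 1: 200 x 40
], dtype=np.int32)

norm = 10  # same default as KieLabelEncode
x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
dxs = (x1s[:, 0][None] - x1s) / norm            # pairwise x-offsets, (2, 2)
dys = (y1s[:, 0][None] - y1s) / norm            # pairwise y-offsets, (2, 2)
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
whs = ws / hs + np.zeros_like(xhhs)             # aspect ratios, broadcast to (2, 2)
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
print(relations.shape)  # (2, 2, 5): one 5-d feature per ordered box pair
print(dys[0, 1])        # 5.0: box 1's top edge sits 50 px below box 0's, / norm
```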
+    def list_to_numpy(self, ann_infos):
+        """Convert bboxes, relations, texts and labels to ndarray."""
+        boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
+        boxes = np.array(boxes, np.int32)
+        relations, bboxes = self.compute_relation(boxes)
+
+        labels = ann_infos.get('labels', None)
+        if labels is not None:
+            labels = np.array(labels, np.int32)
+            edges = ann_infos.get('edges', None)
+            if edges is not None:
+                labels = labels[:, None]
+                edges = np.array(edges)
+                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
+                if self.directed:
+                    edges = (edges & labels == 1).astype(np.int32)
+                np.fill_diagonal(edges, -1)
+                labels = np.concatenate([labels, edges], -1)
+        padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
+        max_num = 300
+        temp_bboxes = np.zeros([max_num, 4])
+        h, _ = bboxes.shape
+        temp_bboxes[:h, :] = bboxes
+
+        temp_relations = np.zeros([max_num, max_num, 5])
+        temp_relations[:h, :h, :] = relations
+
+        temp_padded_text_inds = np.zeros([max_num, max_num])
+        temp_padded_text_inds[:h, :] = padded_text_inds
+
+        temp_labels = np.zeros([max_num, max_num])
+        temp_labels[:h, :h + 1] = labels
+
+        tag = np.array([h, recoder_len])
+        return dict(
+            image=ann_infos['image'],
+            points=temp_bboxes,
+            relations=temp_relations,
+            texts=temp_padded_text_inds,
+            labels=temp_labels,
+            tag=tag)
+
+    def convert_canonical(self, points_x, points_y):
+
+        assert len(points_x) == 4
+        assert len(points_y) == 4
+
+        points = [Point(points_x[i], points_y[i]) for i in range(4)]
+
+        polygon = Polygon([(p.x, p.y) for p in points])
+        min_x, min_y, _, _ = polygon.bounds
+        points_to_lefttop = [
+            LineString([points[i], Point(min_x, min_y)]) for i in range(4)
+        ]
+        distances = np.array([line.length for line in points_to_lefttop])
+        sort_dist_idx = np.argsort(distances)
+        lefttop_idx = sort_dist_idx[0]
+
+        if lefttop_idx == 0:
+            point_orders = [0, 1, 2, 3]
+        elif lefttop_idx == 1:
+            point_orders = [1, 2, 3, 0]
+        elif lefttop_idx == 2:
+            point_orders = [2, 3, 0, 1]
+        else:
+            point_orders = [3, 0, 1, 2]
+
+        sorted_points_x = [points_x[i] for i in point_orders]
+        sorted_points_y = [points_y[j] for j in point_orders]
+
+        return sorted_points_x, sorted_points_y
+
+    def sort_vertex(self, points_x, points_y):
+
+        assert len(points_x) == 4
+        assert len(points_y) == 4
+
+        x = np.array(points_x)
+        y = np.array(points_y)
+        center_x = np.sum(x) * 0.25
+        center_y = np.sum(y) * 0.25
+
+        x_arr = np.array(x - center_x)
+        y_arr = np.array(y - center_y)
+
+        angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
+        sort_idx = np.argsort(angle)
+
+        sorted_points_x, sorted_points_y = [], []
+        for i in range(4):
+            sorted_points_x.append(points_x[sort_idx[i]])
+            sorted_points_y.append(points_y[sort_idx[i]])
+
+        return self.convert_canonical(sorted_points_x, sorted_points_y)
+
+    def __call__(self, data):
+        label = data['label']
+        annotations = json.loads(label)
+        boxes, texts, text_inds, labels, edges = [], [], [], [], []
+        for ann in annotations:
+            box = ann['points']
+            x_list = [box[i][0] for i in range(4)]
+            y_list = [box[i][1] for i in range(4)]
+            sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
+            sorted_box = []
+            for x, y in zip(sorted_x_list, sorted_y_list):
+                sorted_box.append(x)
+                sorted_box.append(y)
+            boxes.append(sorted_box)
+            text = ann['transcription']
+            texts.append(text)
+            text_ind = [self.dict[c] for c in text if c in self.dict]
+            text_inds.append(text_ind)
+            labels.append(ann['label'])
+            edges.append(ann.get('edge', 0))
+        ann_infos = dict(
+            image=data['image'],
+            points=boxes,
+            texts=texts,
+            text_inds=text_inds,
+            edges=edges,
+            labels=labels)
+
+        return self.list_to_numpy(ann_infos)
+
+
 class AttnLabelEncode(BaseRecLabelEncode):
     """ Convert between text-label and text-index """
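`sort_vertex` orders the four corners by angle around the centroid, and `convert_canonical` then rotates the cycle so the corner nearest the top-left of the bounding box comes first. The following standalone sketch reproduces that ordering without shapely, using `np.hypot` in place of `LineString.length`:

```python
import numpy as np

points_x = [110, 10, 10, 110]   # a deliberately shuffled rectangle
points_y = [40, 40, 10, 10]

x, y = np.array(points_x), np.array(points_y)
angle = np.arctan2(y - y.mean(), x - x.mean()) * 180.0 / np.pi
order = np.argsort(angle)       # sort corners by angle around the centroid
sx = [points_x[i] for i in order]
sy = [points_y[i] for i in order]

# distance of each sorted corner to the (min_x, min_y) corner of the bounds
d = np.hypot(np.array(sx) - min(sx), np.array(sy) - min(sy))
start = int(np.argmin(d))       # rotate so the top-left-most corner leads
canonical = [(sx[(start + i) % 4], sy[(start + i) % 4]) for i in range(4)]
print(canonical)  # [(10, 10), (110, 10), (110, 40), (10, 40)]
```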
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 87e3088d07a8c5a2eea5d4deff87c69a753e215b..c3dfd316f86d88b5c7fd52eb6ae23d22a4dd32eb 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -111,7 +111,6 @@ class NormalizeImage(object):
             from PIL import Image
             if isinstance(img, Image.Image):
                 img = np.array(img)
-
         assert isinstance(img,
                           np.ndarray), "invalid input 'img' in NormalizeImage"
         data['image'] = (
@@ -367,3 +366,53 @@ class E2EResizeForTest(object):
         ratio_w = resize_w / float(w)
         return im, (ratio_h, ratio_w)
+
+
+class KieResize(object):
+    def __init__(self, **kwargs):
+        super(KieResize, self).__init__()
+        self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
+            'img_scale'][1]
+
+    def __call__(self, data):
+        img = data['image']
+        points = data['points']
+        src_h, src_w, _ = img.shape
+        im_resized, scale_factor, [ratio_h, ratio_w
+                                   ], [new_h, new_w] = self.resize_image(img)
+        resize_points = self.resize_boxes(img, points, scale_factor)
+        data['ori_image'] = img
+        data['ori_boxes'] = points
+        data['points'] = resize_points
+        data['image'] = im_resized
+        data['shape'] = np.array([new_h, new_w])
+        return data
+
+    def resize_image(self, img):
+        norm_img = np.zeros([1024, 1024, 3], dtype='float32')
+        scale = [512, 1024]
+        h, w = img.shape[:2]
+        max_long_edge = max(scale)
+        max_short_edge = min(scale)
+        scale_factor = min(max_long_edge / max(h, w),
+                           max_short_edge / min(h, w))
+        resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
+            scale_factor) + 0.5)
+        max_stride = 32
+        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+        im = cv2.resize(img, (resize_w, resize_h))
+        new_h, new_w = im.shape[:2]
+        w_scale = new_w / w
+        h_scale = new_h / h
+        scale_factor = np.array(
+            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
+        norm_img[:new_h, :new_w, :] = im
+        return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
+
+    def resize_boxes(self, im, points, scale_factor):
+        points = points * scale_factor
+        img_shape = im.shape[:2]
+        points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
+        points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
+        return points
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index f3f4cd49332b605ec3a0e65e688d965fd91a5cdf..62ad2b6ad86edf9b5446aea03f9333f9d4981336 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -35,6 +35,7 @@ from .cls_loss import ClsLoss
 
 # e2e loss
 from .e2e_pg_loss import PGLoss
+from .kie_sdmgr_loss import SDMGRLoss
 
 # basic loss function
 from .basic_loss import DistanceLoss
@@ -50,7 +51,7 @@ def build_loss(config):
     support_dict = [
         'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
         'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
-        'TableAttentionLoss', 'SARLoss', 'AsterLoss'
+        'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
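`KieResize` keeps the long side within 1024 and the short side within 512, rounds both target dimensions up to multiples of the stride 32, and pastes the result into a fixed 1024 x 1024 canvas. A quick standalone check of that arithmetic for a hypothetical 800 x 600 input:

```python
# Sketch of the KieResize arithmetic (no Paddle or OpenCV needed).
h, w = 800, 600
scale_factor = min(1024 / max(h, w), 512 / min(h, w))  # 512/600 ~ 0.853
resize_w = int(w * scale_factor + 0.5)                 # 512
resize_h = int(h * scale_factor + 0.5)                 # 683
max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride  # 704
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride  # 512
print(resize_h, resize_w)  # 704 512, placed at the top-left of the canvas
```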
diff --git a/ppocr/losses/kie_sdmgr_loss.py b/ppocr/losses/kie_sdmgr_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f2173e49904926ebab2c450890c4fafe3f36b50
--- /dev/null
+++ b/ppocr/losses/kie_sdmgr_loss.py
@@ -0,0 +1,114 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+import paddle
+
+
+class SDMGRLoss(nn.Layer):
+    def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
+        super().__init__()
+        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
+        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
+        self.node_weight = node_weight
+        self.edge_weight = edge_weight
+        self.ignore = ignore
+
+    def pre_process(self, gts, tag):
+        gts, tag = gts.numpy(), tag.numpy().tolist()
+        temp_gts = []
+        batch = len(tag)
+        for i in range(batch):
+            num, recoder_len = tag[i][0], tag[i][1]
+            temp_gts.append(
+                paddle.to_tensor(
+                    gts[i, :num, :num + 1], dtype='int64'))
+        return temp_gts
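`pre_process` strips the 300 x 300 padding introduced by `KieLabelEncode.list_to_numpy`: each sample's `tag` holds `[num_boxes, recoder_len]`, and the meaningful slab is the top-left `(num) x (num + 1)` block, with node classes in column 0 and the edge matrix after it. A toy numpy illustration of that slicing:

```python
import numpy as np

num = 3
gts = np.zeros([1, 300, 300], dtype='int64')          # one padded sample
gts[0, :num, 0] = [5, 11, 5]                          # node labels
gts[0, :num, 1:num + 1] = np.eye(num, dtype='int64')  # toy edge matrix
tag = [[num, 17]]                                     # [num_boxes, recoder_len]

sample = gts[0, :tag[0][0], :tag[0][0] + 1]
print(sample.shape)  # (3, 4)
print(sample[:, 0])  # [ 5 11  5]  node ground truth
print(sample[:, 1:])  # 3 x 3 edge ground truth
```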
+    def accuracy(self, pred, target, topk=1, thresh=None):
+        """Calculate accuracy according to the prediction and target.
+
+        Args:
+            pred (paddle.Tensor): The model prediction, shape (N, num_class)
+            target (paddle.Tensor): The target of each prediction, shape (N, )
+            topk (int | tuple[int], optional): If the predictions in ``topk``
+                matches the target, the predictions will be regarded as
+                correct ones. Defaults to 1.
+            thresh (float, optional): If not None, predictions with scores
+                under this threshold are considered incorrect. Default to None.
+
+        Returns:
+            float | tuple[float]: If the input ``topk`` is a single integer,
+                the function will return a single float as accuracy. If
+                ``topk`` is a tuple containing multiple integers, the
+                function will return a tuple containing accuracies of
+                each ``topk`` number.
+        """
+        assert isinstance(topk, (int, tuple))
+        if isinstance(topk, int):
+            topk = (topk, )
+            return_single = True
+        else:
+            return_single = False
+
+        maxk = max(topk)
+        if pred.shape[0] == 0:
+            # paddle.Tensor has no new_tensor(); build the zero scores directly
+            accu = [paddle.to_tensor(0.) for i in range(len(topk))]
+            return accu[0] if return_single else accu
+        pred_value, pred_label = paddle.topk(pred, maxk, axis=1)
+        pred_label = pred_label.transpose(
+            [1, 0])  # transpose to shape (maxk, N)
+        correct = paddle.equal(pred_label,
+                               (target.reshape([1, -1]).expand_as(pred_label)))
+        res = []
+        for k in topk:
+            correct_k = paddle.sum(correct[:k].reshape([-1]).astype('float32'),
+                                   axis=0,
+                                   keepdim=True)
+            res.append(
+                paddle.multiply(correct_k,
+                                paddle.to_tensor(100.0 / pred.shape[0])))
+        return res[0] if return_single else res
+
+    def forward(self, pred, batch):
+        node_preds, edge_preds = pred
+        gts, tag = batch[4], batch[5]
+        gts = self.pre_process(gts, tag)
+        node_gts, edge_gts = [], []
+        for gt in gts:
+            node_gts.append(gt[:, 0])
+            edge_gts.append(gt[:, 1:].reshape([-1]))
+        node_gts = paddle.concat(node_gts)
+        edge_gts = paddle.concat(edge_gts)
+
+        node_valids = paddle.nonzero(node_gts != self.ignore).reshape([-1])
+        edge_valids = paddle.nonzero(edge_gts != -1).reshape([-1])
+        loss_node = self.loss_node(node_preds, node_gts)
+        loss_edge = self.loss_edge(edge_preds, edge_gts)
+        loss = self.node_weight * loss_node + self.edge_weight * loss_edge
+        return dict(
+            loss=loss,
+            loss_node=loss_node,
+            loss_edge=loss_edge,
+            acc_node=self.accuracy(
+                paddle.gather(node_preds, node_valids),
+                paddle.gather(node_gts, node_valids)),
+            acc_edge=self.accuracy(
+                paddle.gather(edge_preds, edge_valids),
+                paddle.gather(edge_gts, edge_valids)))
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 64f62e51cdf922773c03bb784a4edffdc17f506f..28bff3cb4eb7784db876940f761208f1b084f0e2 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -27,10 +27,13 @@ from .cls_metric import ClsMetric
 from .e2e_metric import E2EMetric
 from .distillation_metric import DistillationMetric
 from .table_metric import TableMetric
+from .kie_metric import KIEMetric
+
 
 def build_metric(config):
     support_dict = [
-        "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
+        "DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
+        "DistillationMetric", "TableMetric", "KIEMetric"
     ]
 
     config = copy.deepcopy(config)
diff --git a/ppocr/metrics/kie_metric.py b/ppocr/metrics/kie_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..761965cfcc25d2a6de30342769d01b36d6212d98
--- /dev/null
+++ b/ppocr/metrics/kie_metric.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+
+__all__ = ['KIEMetric']
+
+
+class KIEMetric(object):
+    def __init__(self, main_indicator='hmean', **kwargs):
+        self.main_indicator = main_indicator
+        self.reset()
+
+    def __call__(self, preds, batch, **kwargs):
+        nodes, _ = preds
+        gts, tag = batch[4].squeeze(0), batch[5].tolist()[0]
+        gts = gts[:tag[0], :1].reshape([-1])
+        self.node.append(nodes.numpy())
+        self.gt.append(gts)
+
+    def compute_f1_score(self, preds, gts):
+        ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]
+        C = preds.shape[1]
+        classes = np.array(sorted(set(range(C)) - set(ignores)))
+        hist = np.bincount(
+            (gts * C).astype('int64') + preds.argmax(1),
+            minlength=C**2).reshape([C, C]).astype('float32')
+        diag = np.diag(hist)
+        recalls = diag / hist.sum(1).clip(min=1)
+        precisions = diag / hist.sum(0).clip(min=1)
+        f1 = 2 * recalls * precisions / (recalls + precisions).clip(min=1e-8)
+        return f1[classes]
+
+    def combine_results(self, results):
+        node = np.concatenate(self.node, 0)
+        gts = np.concatenate(self.gt, 0)
+        results = self.compute_f1_score(node, gts)
+        data = {'hmean': results.mean()}
+        return data
+
+    def get_metric(self):
+        metrics = self.combine_results(self.results)
+        self.reset()
+        return metrics
+
+    def reset(self):
+        self.results = []  # clear results
+        self.node = []
+        self.gt = []
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 66b507fd24158ddf64d68dd7392f828a2e17c399..d10983487bedb0fc4278095db08d1f234ef5c595 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -35,7 +35,14 @@ def build_backbone(config, model_type):
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
-        support_dict = ["ResNet"]
+        support_dict = ['ResNet']
+    elif model_type == 'kie':
+        from .kie_unet_sdmgr import Kie_backbone
+        support_dict = ['Kie_backbone']
+    elif model_type == "table":
+        from .table_resnet_vd import ResNet
+        from .table_mobilenet_v3 import MobileNetV3
+        support_dict = ["ResNet", "MobileNetV3"]
     else:
         raise NotImplementedError
 
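The `np.bincount` call in `compute_f1_score` in `kie_metric.py` above is a compact confusion-matrix construction: flattening `gt * C + pred` and counting occurrences yields a C x C matrix with ground truth on rows and predictions on columns. A small sanity check with three classes:

```python
import numpy as np

C = 3
gts = np.array([0, 1, 2, 2])
pred_labels = np.array([0, 1, 1, 2])
hist = np.bincount(gts * C + pred_labels, minlength=C**2).reshape([C, C])
print(hist)
# [[1 0 0]
#  [0 1 0]
#  [0 1 1]]
diag = np.diag(hist).astype('float32')
recalls = diag / hist.sum(1).clip(min=1)     # per-class recall
precisions = diag / hist.sum(0).clip(min=1)  # per-class precision
f1 = 2 * recalls * precisions / (recalls + precisions).clip(min=1e-8)
print(f1)  # [1. 0.667 0.667] -> hmean is the mean over non-ignored classes
```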
diff --git a/ppocr/modeling/backbones/kie_unet_sdmgr.py b/ppocr/modeling/backbones/kie_unet_sdmgr.py
new file mode 100644
index 0000000000000000000000000000000000000000..545e4e7511e58c3d8220e9ec0be35474deba8806
--- /dev/null
+++ b/ppocr/modeling/backbones/kie_unet_sdmgr.py
@@ -0,0 +1,188 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import numpy as np
+import cv2
+
+__all__ = ["Kie_backbone"]
+
+
+class Encoder(nn.Layer):
+    def __init__(self, num_channels, num_filters):
+        super(Encoder, self).__init__()
+        self.conv1 = nn.Conv2D(
+            num_channels,
+            num_filters,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn1 = nn.BatchNorm(num_filters, act='relu')
+
+        self.conv2 = nn.Conv2D(
+            num_filters,
+            num_filters,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn2 = nn.BatchNorm(num_filters, act='relu')
+
+        self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+    def forward(self, inputs):
+        x = self.conv1(inputs)
+        x = self.bn1(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x_pooled = self.pool(x)
+        return x, x_pooled
+
+
+class Decoder(nn.Layer):
+    def __init__(self, num_channels, num_filters):
+        super(Decoder, self).__init__()
+
+        self.conv1 = nn.Conv2D(
+            num_channels,
+            num_filters,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn1 = nn.BatchNorm(num_filters, act='relu')
+
+        self.conv2 = nn.Conv2D(
+            num_filters,
+            num_filters,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn2 = nn.BatchNorm(num_filters, act='relu')
+
+        self.conv0 = nn.Conv2D(
+            num_channels,
+            num_filters,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias_attr=False)
+        self.bn0 = nn.BatchNorm(num_filters, act='relu')
+
+    def forward(self, inputs_prev, inputs):
+        x = self.conv0(inputs)
+        x = self.bn0(x)
+        x = paddle.nn.functional.interpolate(
+            x, scale_factor=2, mode='bilinear', align_corners=False)
+        x = paddle.concat([inputs_prev, x], axis=1)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        return x
+
+
+class UNet(nn.Layer):
+    def __init__(self):
+        super(UNet, self).__init__()
+        self.down1 = Encoder(num_channels=3, num_filters=16)
+        self.down2 = Encoder(num_channels=16, num_filters=32)
+        self.down3 = Encoder(num_channels=32, num_filters=64)
+        self.down4 = Encoder(num_channels=64, num_filters=128)
+        self.down5 = Encoder(num_channels=128, num_filters=256)
+
+        self.up1 = Decoder(32, 16)
+        self.up2 = Decoder(64, 32)
+        self.up3 = Decoder(128, 64)
+        self.up4 = Decoder(256, 128)
+        self.out_channels = 16
+
+    def forward(self, inputs):
+        x1, _ = self.down1(inputs)
+        _, x2 = self.down2(x1)
+        _, x3 = self.down3(x2)
+        _, x4 = self.down4(x3)
+        _, x5 = self.down5(x4)
+
+        x = self.up4(x4, x5)
+        x = self.up3(x3, x)
+        x = self.up2(x2, x)
+        x = self.up1(x1, x)
+        return x
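The UNet above downsamples to 1/16 resolution through `down2`..`down5` and recovers full resolution through the four decoder steps, ending with `out_channels = 16`. A shape walkthrough, not part of the patch, requiring Paddle and assuming the module path added in this diff:

```python
import paddle
from ppocr.modeling.backbones.kie_unet_sdmgr import UNet

net = UNet()
x = paddle.randn([1, 3, 512, 512])
# encoder: x1 keeps full resolution, x2..x5 are pooled to 1/2 .. 1/16;
# decoder: each up-step doubles the resolution back, so the output is
# full resolution with 16 channels.
y = net(x)
print(y.shape)  # [1, 16, 512, 512]
```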
+
+class Kie_backbone(nn.Layer):
+    def __init__(self, in_channels, **kwargs):
+        super(Kie_backbone, self).__init__()
+        self.out_channels = 16
+        self.img_feat = UNet()
+        self.maxpool = nn.MaxPool2D(kernel_size=7)
+
+    def bbox2roi(self, bbox_list):
+        rois_list = []
+        rois_num = []
+        for img_id, bboxes in enumerate(bbox_list):
+            rois_num.append(bboxes.shape[0])
+            rois_list.append(bboxes)
+        rois = paddle.concat(rois_list, 0)
+        rois_num = paddle.to_tensor(rois_num, dtype='int32')
+        return rois, rois_num
+
+    def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
+        img, relations, texts, gt_bboxes, tag, img_size = img.numpy(
+        ), relations.numpy(), texts.numpy(), gt_bboxes.numpy(), tag.numpy(
+        ).tolist(), img_size.numpy()
+        temp_relations, temp_texts, temp_gt_bboxes = [], [], []
+        h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1]))
+        img = paddle.to_tensor(img[:, :, :h, :w])
+        batch = len(tag)
+        for i in range(batch):
+            num, recoder_len = tag[i][0], tag[i][1]
+            temp_relations.append(
+                paddle.to_tensor(
+                    relations[i, :num, :num, :], dtype='float32'))
+            temp_texts.append(
+                paddle.to_tensor(
+                    texts[i, :num, :recoder_len], dtype='float32'))
+            temp_gt_bboxes.append(
+                paddle.to_tensor(
+                    gt_bboxes[i, :num, ...], dtype='float32'))
+        return img, temp_relations, temp_texts, temp_gt_bboxes
+
+    def forward(self, inputs):
+        img = inputs[0]
+        relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
+            2], inputs[3], inputs[5], inputs[-1]
+        img, relations, texts, gt_bboxes = self.pre_process(
+            img, relations, texts, gt_bboxes, tag, img_size)
+        x = self.img_feat(img)
+        boxes, rois_num = self.bbox2roi(gt_bboxes)
+        # NOTE: paddle.fluid is deprecated in newer Paddle releases, which
+        # expose the same op as paddle.vision.ops.roi_align.
+        feats = paddle.fluid.layers.roi_align(
+            x,
+            boxes,
+            spatial_scale=1.0,
+            pooled_height=7,
+            pooled_width=7,
+            rois_num=rois_num)
+        feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
+        return [relations, texts, feats]
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index fdadfed5e3fe30b6bd311a07d6ba36869f175488..4a27ce52a64da5a53d524f58d7613669171d5662 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -33,14 +33,17 @@ def build_head(config):
 
     # cls head
     from .cls_head import ClsHead
+
+    # kie head
+    from .kie_sdmgr_head import SDMGRHead
+
+    from .table_att_head import TableAttentionHead
 
     support_dict = [
         'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
         'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
-        'TableAttentionHead', 'SARHead', 'AsterHead'
+        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead'
     ]
 
-    #table head
-    from .table_att_head import TableAttentionHead
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('head only support {}'.format(
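`Kie_backbone` above flattens the per-image box lists into one RoI tensor plus a per-image count, then pools a 7 x 7 patch per text box from the shared UNet feature map. A hedged sketch of that step in isolation, assuming a Paddle version where `paddle.fluid.layers.roi_align` (the call used by this patch) is still available:

```python
import paddle

feats = paddle.randn([2, 16, 64, 64])             # batch of 2 feature maps
boxes_img0 = paddle.to_tensor([[0., 0., 16., 16.]])
boxes_img1 = paddle.to_tensor([[8., 8., 32., 24.], [0., 0., 64., 64.]])

rois = paddle.concat([boxes_img0, boxes_img1], 0)  # what bbox2roi produces
rois_num = paddle.to_tensor([1, 2], dtype='int32')
pooled = paddle.fluid.layers.roi_align(
    feats, rois, pooled_height=7, pooled_width=7,
    spatial_scale=1.0, rois_num=rois_num)
print(pooled.shape)  # [3, 16, 7, 7] -> one 7x7 feature patch per text box
```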
diff --git a/ppocr/modeling/heads/kie_sdmgr_head.py b/ppocr/modeling/heads/kie_sdmgr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ac0ed8dcaccb7628ef87fbe851a2b6acd60d55
--- /dev/null
+++ b/ppocr/modeling/heads/kie_sdmgr_head.py
@@ -0,0 +1,206 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class SDMGRHead(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 num_chars=92,
+                 visual_dim=16,
+                 fusion_dim=1024,
+                 node_input=32,
+                 node_embed=256,
+                 edge_input=5,
+                 edge_embed=256,
+                 num_gnn=2,
+                 num_classes=26,
+                 bidirectional=False):
+        super().__init__()
+
+        self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
+        self.node_embed = nn.Embedding(num_chars, node_input, 0)
+        hidden = node_embed // 2 if bidirectional else node_embed
+        self.rnn = nn.LSTM(
+            input_size=node_input, hidden_size=hidden, num_layers=1)
+        self.edge_embed = nn.Linear(edge_input, edge_embed)
+        self.gnn_layers = nn.LayerList(
+            [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
+        self.node_cls = nn.Linear(node_embed, num_classes)
+        self.edge_cls = nn.Linear(edge_embed, 2)
+
+    def forward(self, input, targets):
+        relations, texts, x = input
+        node_nums, char_nums = [], []
+        for text in texts:
+            node_nums.append(text.shape[0])
+            char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))
+
+        max_num = max([char_num.max() for char_num in char_nums])
+        all_nodes = paddle.concat([
+            paddle.concat(
+                [text, paddle.zeros(
+                    (text.shape[0], max_num - text.shape[1]))], -1)
+            for text in texts
+        ])
+        temp = paddle.clip(all_nodes, min=0).astype(int)
+        embed_nodes = self.node_embed(temp)
+        rnn_nodes, _ = self.rnn(embed_nodes)
+
+        b, h, w = rnn_nodes.shape
+        nodes = paddle.zeros([b, w])
+        all_nums = paddle.concat(char_nums)
+        valid = paddle.nonzero((all_nums > 0).astype(int))
+        temp_all_nums = (
+            paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
+        temp_all_nums = paddle.expand(temp_all_nums, [
+            temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
+        ])
+        temp_all_nodes = paddle.gather(rnn_nodes, valid)
+        N, C, A = temp_all_nodes.shape
+        one_hot = F.one_hot(
+            temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
+        one_hot = paddle.multiply(
+            temp_all_nodes, one_hot.astype("float32")).sum(axis=1, keepdim=True)
+        t = one_hot.expand([N, 1, A]).squeeze(1)
+        nodes = paddle.scatter(nodes, valid.squeeze(1), t)
+
+        if x is not None:
+            nodes = self.fusion([x, nodes])
+
+        all_edges = paddle.concat(
+            [rel.reshape([-1, rel.shape[-1]]) for rel in relations])
+        embed_edges = self.edge_embed(all_edges.astype('float32'))
+        embed_edges = F.normalize(embed_edges)
+
+        for gnn_layer in self.gnn_layers:
+            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
+
+        node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
+        return node_cls, edge_cls
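The `GNNLayer` that follows builds one feature per ordered node pair: for a sample with `num` nodes it tiles the node features both ways and concatenates them, yielding `num**2` rows. A numpy sketch of that pairing:

```python
import numpy as np

num, dim = 3, 4
nodes = np.arange(num * dim, dtype='float32').reshape([num, dim])
src = np.broadcast_to(nodes[:, None, :], (num, num, dim))  # row i repeated
dst = np.broadcast_to(nodes[None, :, :], (num, num, dim))  # row j repeated
pairs = np.concatenate([src, dst], -1).reshape([num**2, 2 * dim])
print(pairs.shape)   # (9, 8): pair (i, j) holds [node_i, node_j]
print(pairs[5, :4])  # node 1 features (pair index 5 = i * num + j = 1*3 + 2)
```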
+
+class GNNLayer(nn.Layer):
+    def __init__(self, node_dim=256, edge_dim=256):
+        super().__init__()
+        self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
+        self.coef_fc = nn.Linear(node_dim, 1)
+        self.out_fc = nn.Linear(node_dim, node_dim)
+        self.relu = nn.ReLU()
+
+    def forward(self, nodes, edges, nums):
+        start, cat_nodes = 0, []
+        for num in nums:
+            sample_nodes = nodes[start:start + num]
+            cat_nodes.append(
+                paddle.concat([
+                    paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
+                    paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1])
+                ], -1).reshape([num**2, -1]))
+            start += num
+        cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
+        cat_nodes = self.relu(self.in_fc(cat_nodes))
+        coefs = self.coef_fc(cat_nodes)
+
+        start, residuals = 0, []
+        for num in nums:
+            residual = F.softmax(
+                -paddle.eye(num).unsqueeze(-1) * 1e9 +
+                coefs[start:start + num**2].reshape([num, num, -1]), 1)
+            residuals.append((residual * cat_nodes[start:start + num**2]
+                              .reshape([num, num, -1])).sum(1))
+            start += num**2
+
+        nodes += self.relu(self.out_fc(paddle.concat(residuals)))
+        return [nodes, cat_nodes]
+
+
+class Block(nn.Layer):
+    def __init__(self,
+                 input_dims,
+                 output_dim,
+                 mm_dim=1600,
+                 chunks=20,
+                 rank=15,
+                 shared=False,
+                 dropout_input=0.,
+                 dropout_pre_lin=0.,
+                 dropout_output=0.,
+                 pos_norm='before_cat'):
+        super().__init__()
+        self.rank = rank
+        self.dropout_input = dropout_input
+        self.dropout_pre_lin = dropout_pre_lin
+        self.dropout_output = dropout_output
+        assert (pos_norm in ['before_cat', 'after_cat'])
+        self.pos_norm = pos_norm
+        # Modules
+        self.linear0 = nn.Linear(input_dims[0], mm_dim)
+        self.linear1 = (self.linear0
+                        if shared else nn.Linear(input_dims[1], mm_dim))
+        self.merge_linears0 = nn.LayerList()
+        self.merge_linears1 = nn.LayerList()
+        self.chunks = self.chunk_sizes(mm_dim, chunks)
+        for size in self.chunks:
+            ml0 = nn.Linear(size, size * rank)
+            self.merge_linears0.append(ml0)
+            ml1 = ml0 if shared else nn.Linear(size, size * rank)
+            self.merge_linears1.append(ml1)
+        self.linear_out = nn.Linear(mm_dim, output_dim)
+
+    def forward(self, x):
+        x0 = self.linear0(x[0])
+        x1 = self.linear1(x[1])
+        bs = x1.shape[0]
+        if self.dropout_input > 0:
+            x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
+            x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
+        x0_chunks = paddle.split(x0, self.chunks, -1)
+        x1_chunks = paddle.split(x1, self.chunks, -1)
+        zs = []
+        for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0,
+                                      self.merge_linears1):
+            m = m0(x0_c) * m1(x1_c)  # bs x split_size*rank
+            m = m.reshape([bs, self.rank, -1])
+            z = paddle.sum(m, 1)
+            if self.pos_norm == 'before_cat':
+                z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
+                z = F.normalize(z)
+            zs.append(z)
+        z = paddle.concat(zs, 1)
+        if self.pos_norm == 'after_cat':
+            z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
+            z = F.normalize(z)
+
+        if self.dropout_pre_lin > 0:
+            z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
+        z = self.linear_out(z)
+        if self.dropout_output > 0:
+            z = F.dropout(z, p=self.dropout_output, training=self.training)
+        return z
+
+    def chunk_sizes(self, dim, chunks):
+        split_size = (dim + chunks - 1) // chunks
+        sizes_list = [split_size] * chunks
+        sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
+        return sizes_list
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index c6cb0144f7efd9ff7976ad67a658a554eafce754..37dadd12d3f628b1802b6a31f611f49f3ac600c2 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -45,6 +45,8 @@ def build_post_process(config, global_config=None):
 
     config = copy.deepcopy(config)
     module_name = config.pop('name')
+    if module_name == "None":
+        return
     if global_config is not None:
         config.update(global_config)
     assert module_name in support_dict, Exception(
diff --git a/requirements.txt b/requirements.txt
index 0c87c5c95069a2699f5a3a50320c883c6118ffe7..9900588b25df99e0853ec4521f0632578c55f530 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ cython
 lxml
 premailer
 openpyxl
-fasttext==0.9.1
\ No newline at end of file
+fasttext==0.9.1
+
diff --git a/tools/eval.py b/tools/eval.py
index c85490a316772e9dfdfe3267087ea3946a2a3b72..13a4a0882f5a20b47e8999042713e1623b32ff5a 100755
--- a/tools/eval.py
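For reference, `Block.chunk_sizes` above partitions the `mm_dim`-wide fused feature into near-equal chunks, rounding up and then shrinking the last chunk so the sizes sum to `dim` exactly:

```python
def chunk_sizes(dim, chunks):
    split_size = (dim + chunks - 1) // chunks
    sizes_list = [split_size] * chunks
    sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
    return sizes_list

print(chunk_sizes(1600, 20))  # [80, 80, ..., 80], the default Block config
print(chunk_sizes(100, 3))    # [34, 34, 32]
```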
+++ b/tools/eval.py
@@ -54,7 +54,8 @@ def main():
             config['Architecture']["Head"]['out_channels'] = char_num
 
     model = build_model(config['Architecture'])
-    extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
+    extra_input = config['Architecture'][
+        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
     if "model_type" in config['Architecture'].keys():
         model_type = config['Architecture']['model_type']
     else:
@@ -68,7 +69,6 @@ def main():
 
     # build metric
     eval_class = build_metric(config['Metric'])
-
    # start eval
     metric = program.eval(model, valid_dataloader, post_process_class,
                           eval_class, model_type, extra_input)
diff --git a/tools/infer_kie.py b/tools/infer_kie.py
new file mode 100755
index 0000000000000000000000000000000000000000..62ef697240ffe89fcb858c5308bd010105dde2ab
--- /dev/null
+++ b/tools/infer_kie.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.nn.functional as F
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.utils.save_load import init_model
+import tools.program as program
+
+
+def read_class_list(filepath):
+    class_dict = {}
+    with open(filepath, "r") as f:
+        lines = f.readlines()
+        for line in lines:
+            key, value = line.split(" ")
+            class_dict[key] = value.rstrip()
+    return class_dict
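`read_class_list` expects one "index name" pair per line, as in WildReceipt's class_list.txt, and keys the result by the stringified index, which matches the `str(...)` lookup in `draw_kie_result` below. A toy run over an illustrative temp file:

```python
import os
import tempfile

content = "0 Ignore\n1 Store_name_value\n2 Store_name_key\n"
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write(content)
    path = f.name

idx_to_cls = {}
with open(path, "r") as f:
    for line in f.readlines():
        key, value = line.split(" ")
        idx_to_cls[key] = value.rstrip()
os.remove(path)
print(idx_to_cls)
# {'0': 'Ignore', '1': 'Store_name_value', '2': 'Store_name_key'}
```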
+def draw_kie_result(batch, node, idx_to_cls, count):
+    img = batch[6].copy()
+    boxes = batch[7]
+    h, w = img.shape[:2]
+    pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
+    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
+    node_pred_label = max_idx.numpy().tolist()
+    node_pred_score = max_value.numpy().tolist()
+
+    for i, box in enumerate(boxes):
+        if i >= len(node_pred_label):
+            break
+        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
+                   [box[0], box[3]]]
+        Pts = np.array([new_box], np.int32)
+        cv2.polylines(
+            img, [Pts.reshape((-1, 1, 2))],
+            True,
+            color=(255, 255, 0),
+            thickness=1)
+        x_min = int(min([point[0] for point in new_box]))
+        y_min = int(min([point[1] for point in new_box]))
+
+        pred_label = str(node_pred_label[i])
+        if pred_label in idx_to_cls:
+            pred_label = idx_to_cls[pred_label]
+        pred_score = '{:.2f}'.format(node_pred_score[i])
+        text = pred_label + '(' + pred_score + ')'
+        cv2.putText(pred_img, text, (x_min * 2, y_min),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
+    vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
+    vis_img[:, :w] = img
+    vis_img[:, w:] = pred_img
+    save_kie_path = os.path.dirname(config['Global'][
+        'save_res_path']) + "/kie_results/"
+    if not os.path.exists(save_kie_path):
+        os.makedirs(save_kie_path)
+    save_path = os.path.join(save_kie_path, str(count) + ".png")
+    cv2.imwrite(save_path, vis_img)
+    logger.info("The KIE image is saved in {}".format(save_path))
+
+
+def main():
+    global_config = config['Global']
+
+    # build model
+    model = build_model(config['Architecture'])
+    init_model(config, model, logger)
+
+    # create data ops
+    transforms = []
+    for op in config['Eval']['dataset']['transforms']:
+        transforms.append(op)
+
+    data_dir = config['Eval']['dataset']['data_dir']
+
+    ops = create_operators(transforms, global_config)
+
+    save_res_path = config['Global']['save_res_path']
+    class_path = config['Global']['class_path']
+    idx_to_cls = read_class_list(class_path)
+    if not os.path.exists(os.path.dirname(save_res_path)):
+        os.makedirs(os.path.dirname(save_res_path))
+
+    model.eval()
+    with open(save_res_path, "wb") as fout:
+        with open(config['Global']['infer_img'], "rb") as f:
+            lines = f.readlines()
+            for index, data_line in enumerate(lines):
+                data_line = data_line.decode('utf-8')
+                substr = data_line.strip("\n").split("\t")
+                img_path, label = data_dir + "/" + substr[0], substr[1]
+                data = {'img_path': img_path, 'label': label}
+                with open(data['img_path'], 'rb') as f:
+                    img = f.read()
+                    data['image'] = img
+                batch = transform(data, ops)
+                batch_pred = [0] * len(batch)
+                for i in range(len(batch)):
+                    batch_pred[i] = paddle.to_tensor(
+                        np.expand_dims(
+                            batch[i], axis=0))
+                node, edge = model(batch_pred)
+                node = F.softmax(node, -1)
+                draw_kie_result(batch, node, idx_to_cls, index)
+    logger.info("success!")
+
+
+if __name__ == '__main__':
+    config, device, logger, vdl_writer = program.preprocess()
+    main()
diff --git a/tools/program.py b/tools/program.py
index d110f70704028948dff2bc889e07d128e0bc94ea..c1547efbcebc5ee8522aa7f190c44d602b595880 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -227,6 +227,10 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
+            if model_type == 'table' or extra_input:
+                preds = model(images, data=batch[1:])
+            elif model_type == "kie":
+                preds = model(batch)
 
             train_start = time.time()
             # use amp
@@ -266,7 +270,7 @@ def train(config,
 
             if cal_metric_during_train:  # only rec and cls need
                 batch = [item.numpy() for item in batch]
-                if model_type == 'table':
+                if model_type in ['table', 'kie']:
                     eval_class(preds, batch)
                 else:
                     post_result = post_process_class(preds, batch[1])
@@ -399,17 +403,20 @@ def eval(model,
             start = time.time()
             if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
+            elif model_type == "kie":
+                preds = model(batch)
             else:
                 preds = model(images)
             batch = [item.numpy() for item in batch]
             # Obtain usable results from post-processing methods
             total_time += time.time() - start
             # Evaluate the results of the current batch
-            if model_type == 'table':
+            if model_type in ['table', 'kie']:
                 eval_class(preds, batch)
             else:
                 post_result = post_process_class(preds, batch[1])
                 eval_class(post_result, batch)
+
             pbar.update(1)
             total_frame += len(images)
         # Get final metric,eg. acc or hmean
@@ -498,8 +505,13 @@ def preprocess(is_train=False):
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED'
+        'SEED', 'SDMGR'
     ]
+    windows_not_support_list = ['PSE']
+    if platform.system() == "Windows" and alg in windows_not_support_list:
+        logger.warning('{} is not supported on Windows yet'.format(
+            windows_not_support_list))
+        sys.exit()
 
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
     device = paddle.set_device(device)