提交 ddaa2c25 编写于 作者: 文幕地方's avatar 文幕地方

add SLANet

上级 342522ab
......@@ -32,3 +32,31 @@ paddleocr.egg-info/
/deploy/android_demo/app/cache/
test_tipc/web/models/
test_tipc/web/node_modules/
en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams
en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams.info
en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdmodel
en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams
en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams.info
en_ppocr_mobile_v2.0_table_structure_infer/inference.pdmodel
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams.info
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdmodel
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams.info
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdmodel
.gitignore
.gitignore
ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams
ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams.info
ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdmodel
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/infer_cfg.yml
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams.info
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdmodel
.gitignore
ppstructure/layout/table/inference.pdiparams
ppstructure/layout/table/inference.pdiparams.info
ppstructure/layout/table/inference.pdmodel
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape.tar
._en_ppocr_mobile_v2.0_table_structure_infer
en_ppocr_mobile_v2.0_table_structure_infer.tar
Global:
use_gpu: true
epoch_num: 400
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./output/SLANet
save_epoch_step: 400
# evaluation is run every 1000 iterations after the 0th iteration
eval_batch_step: [0, 1000]
cal_metric_during_train: True
pretrained_model:
checkpoints: /ssd1/zhoujun20/table/ch/PaddleOCR/output/en/table_lcnet_1_0_csp_pan_headsv3_smooth_l1_pretrain_ssld_weight81_sync_bn/best_accuracy.pdparams
save_inference_dir: ./output/SLANet/infer
use_visualdl: False
infer_img: doc/table/table.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: &max_text_length 500
box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
infer_mode: False
use_sync_bn: True
save_res_path: 'output/infer'
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 5.0
lr:
# name: Piecewise
learning_rate: 0.001
# decay_epochs : [10, 20]
# values : [0.002, 0.0002, 0.0001]
# warmup_epoch: 0
regularizer:
name: 'L2'
factor: 0.00000
Architecture:
model_type: table
algorithm: SLANet
Backbone:
name: PPLCNet
scale: 1.0
pretrained: true
use_ssld: true
Neck:
name: CSPPAN
out_channels: 96
Head:
name: SLAHead
hidden_size: 256
max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4
Loss:
name: SLANetLoss
structure_weight: 1.0
loc_weight: 2.0
loc_loss: smooth_l1
PostProcess:
name: TableLabelDecode
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
loc_reg_num: *loc_reg_num
box_format: *box_format
Train:
dataset:
name: PubTabDataSet
data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/train/
label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode:
box_format: *box_format
- ResizeTableImage:
max_len: 488
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: True
batch_size_per_card: 48
drop_last: True
num_workers: 1
Eval:
dataset:
name: PubTabDataSet
data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_val.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode:
box_format: *box_format
- ResizeTableImage:
max_len: 488
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 48
num_workers: 1
......@@ -15,9 +15,8 @@ Global:
save_res_path: ./output/table_master
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
max_text_length: &max_text_length 500
box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'
Optimizer:
......@@ -52,7 +51,8 @@ Architecture:
headers: 8
dropout: 0
d_ff: 2024
max_text_length: 500
max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4
Loss:
name: TableMasterLoss
......@@ -66,6 +66,7 @@ Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
box_format: *box_format
Train:
dataset:
......@@ -80,13 +81,15 @@ Train:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
box_format: *box_format
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
......@@ -114,13 +117,15 @@ Eval:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
box_format: *box_format
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
......
......@@ -17,10 +17,9 @@ Global:
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 800
max_text_length: &max_text_length 800
box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
infer_mode: False
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
......@@ -44,7 +43,8 @@ Architecture:
name: TableAttentionHead
hidden_size: 256
loc_type: 2
max_text_length: 800
max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4
Loss:
name: TableAttentionLoss
......@@ -72,6 +72,8 @@ Train:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
......@@ -104,6 +106,8 @@ Eval:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
......
......@@ -118,7 +118,7 @@ class OCRSystem(hub.Module):
all_results.append([])
continue
starttime = time.time()
dt_boxes, rec_res = self.text_sys(img)
dt_boxes, rec_res, _ = self.text_sys(img)
elapse = time.time() - starttime
logger.info("Predict time: {}".format(elapse))
......
......@@ -571,7 +571,7 @@ class TableLabelEncode(AttnLabelEncode):
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
loc_reg_num=4,
**kwargs):
self.max_text_len = max_text_length
self.lower = False
......@@ -593,7 +593,7 @@ class TableLabelEncode(AttnLabelEncode):
self.idx2char = {v: k for k, v in self.dict.items()}
self.character = dict_character
self.point_num = point_num
self.loc_reg_num = loc_reg_num
self.pad_idx = self.dict[self.beg_str]
self.start_idx = self.dict[self.beg_str]
self.end_idx = self.dict[self.end_str]
......@@ -649,7 +649,7 @@ class TableLabelEncode(AttnLabelEncode):
# encode box
bboxes = np.zeros(
(self._max_text_len, self.point_num * 2), dtype=np.float32)
(self._max_text_len, self.loc_reg_num), dtype=np.float32)
bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
bbox_idx = 0
......@@ -714,11 +714,11 @@ class TableMasterLabelEncode(TableLabelEncode):
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
loc_reg_num=4,
**kwargs):
super(TableMasterLabelEncode, self).__init__(
max_text_length, character_dict_path, replace_empty_cell_token,
merge_no_span_structure, learn_empty_box, point_num, **kwargs)
merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
self.pad_idx = self.dict[self.pad_str]
self.unknown_idx = self.dict[self.unknown_str]
......@@ -739,13 +739,14 @@ class TableMasterLabelEncode(TableLabelEncode):
class TableBoxEncode(object):
def __init__(self, use_xywh=False, **kwargs):
self.use_xywh = use_xywh
def __init__(self, box_format='xyxy', **kwargs):
assert box_format in ['xywh', 'xyxy', 'xyxyxyxy']
self.box_format = box_format
def __call__(self, data):
img_height, img_width = data['image'].shape[:2]
bboxes = data['bboxes']
if self.use_xywh and bboxes.shape[1] == 4:
if self.box_format == 'xywh' and bboxes.shape[1] == 4:
bboxes = self.xyxy2xywh(bboxes)
bboxes[:, 0::2] /= img_width
bboxes[:, 1::2] /= img_height
......@@ -1217,6 +1218,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
dict_character = ['</s>'] + dict_character
return dict_character
class SPINLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """
......@@ -1229,6 +1231,7 @@ class SPINLabelEncode(AttnLabelEncode):
super(SPINLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.lower = lower
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
......@@ -1248,4 +1251,4 @@ class SPINLabelEncode(AttnLabelEncode):
padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
return data
\ No newline at end of file
return data
......@@ -206,7 +206,7 @@ class ResizeTableImage(object):
data['bboxes'] = data['bboxes'] * ratio
data['image'] = resize_img
data['src_img'] = img
data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
data['shape'] = np.array([height, width, ratio, ratio])
data['max_len'] = self.max_len
return data
......
......@@ -51,7 +51,7 @@ from .basic_loss import DistanceLoss
from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
from .table_att_loss import TableAttentionLoss, SLANetLoss
from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
......@@ -63,7 +63,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss'
'TableMasterLoss', 'SPINAttentionLoss', 'SLANetLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
......@@ -22,65 +22,11 @@ from paddle.nn import functional as F
class TableAttentionLoss(nn.Layer):
def __init__(self,
structure_weight,
loc_weight,
use_giou=False,
giou_weight=1.0,
**kwargs):
def __init__(self, structure_weight, loc_weight, **kwargs):
super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
self.use_giou = use_giou
self.giou_weight = giou_weight
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
'''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:return: loss
'''
ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])
iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)
# overlap
inters = iw * ih
# union
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
) * (bbox[:, 3] - bbox[:, 1] +
1e-3) - inters + eps
# ious
ious = inters / uni
ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)
# enclose erea
enclose = ew * eh + eps
giou = ious - (enclose - uni) / enclose
loss = 1 - giou
if reduction == 'mean':
loss = paddle.mean(loss)
elif reduction == 'sum':
loss = paddle.sum(loss)
else:
raise NotImplementedError
return loss
def forward(self, predicts, batch):
structure_probs = predicts['structure_probs']
......@@ -100,20 +46,48 @@ class TableAttentionLoss(nn.Layer):
loc_targets_mask = loc_targets_mask[:, 1:, :]
loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
loc_targets) * self.loc_weight
if self.use_giou:
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
loc_targets) * self.giou_weight
total_loss = structure_loss + loc_loss + loc_loss_giou
return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss,
"loc_loss_giou": loc_loss_giou
}
else:
total_loss = structure_loss + loc_loss
return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss
}
total_loss = structure_loss + loc_loss
return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss
}
class SLANetLoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
super(SLANetLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
self.loc_loss = loc_loss
self.eps = 1e-12
def forward(self, predicts, batch):
structure_probs = predicts['structure_probs']
structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:]
structure_loss = self.loss_func(structure_probs, structure_targets)
structure_loss = paddle.mean(structure_loss) * self.structure_weight
loc_preds = predicts['loc_preds']
loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[3].astype("float32")
loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :]
loc_loss = F.smooth_l1_loss(
loc_preds * loc_targets_mask,
loc_targets * loc_targets_mask,
reduction='sum') * self.loc_weight
loc_loss = loc_loss / (loc_targets_mask.sum() + self.eps)
total_loss = structure_loss + loc_loss
return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss
}
......@@ -59,7 +59,7 @@ class TableMetric(object):
def __init__(self,
main_indicator='acc',
compute_bbox_metric=False,
point_num=2,
box_format='xyxy',
**kwargs):
"""
......@@ -70,7 +70,7 @@ class TableMetric(object):
self.structure_metric = TableStructureMetric()
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.point_num = point_num
self.box_format = box_format
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
......@@ -129,10 +129,14 @@ class TableMetric(object):
self.bbox_metric.reset()
def format_box(self, box):
if self.point_num == 2:
if self.box_format == 'xyxy':
x1, y1, x2, y2 = box
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
elif self.point_num == 4:
elif self.box_format == 'xywh':
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
elif self.box_format == 'xyxyxyxy':
x1, y1, x2, y2, x3, y3, x4, y4 = box
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
return box
......@@ -21,7 +21,10 @@ def build_backbone(config, model_type):
from .det_resnet import ResNet
from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST
support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
from .det_pp_lcnet import PPLCNet
support_dict = [
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"
]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
support_dict.append('TableResNetExtra')
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import os
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingNormal
from paddle.utils.download import get_path_from_url
MODEL_URLS = {
"PPLCNet_x0.25":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams",
"PPLCNet_x0.35":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams",
"PPLCNet_x0.5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams",
"PPLCNet_x0.75":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams",
"PPLCNet_x1.0":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams",
"PPLCNet_x1.5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_5_pretrained.pdparams",
"PPLCNet_x2.0":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams",
"PPLCNet_x2.5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams"
}
MODEL_STAGES_PATTERN = {
"PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"]
}
__all__ = list(MODEL_URLS.keys())
# Each element(list) represents a depthwise block, which is composed of k, in_c, out_c, s, use_se.
# k: kernel_size
# in_c: input channel number in depthwise block
# out_c: output channel number in depthwise block
# s: stride in depthwise block
# use_se: whether to use SE block
NET_CONFIG = {
"blocks2":
# k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
"blocks5":
[[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
[5, 256, 256, 1, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
}
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
num_groups=1):
super().__init__()
self.conv = Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=num_groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)
self.bn = BatchNorm(
num_filters,
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self.hardswish = nn.Hardswish()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.hardswish(x)
return x
class DepthwiseSeparable(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
dw_size=3,
use_se=False):
super().__init__()
self.use_se = use_se
self.dw_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=num_channels,
filter_size=dw_size,
stride=stride,
num_groups=num_channels)
if use_se:
self.se = SEModule(num_channels)
self.pw_conv = ConvBNLayer(
num_channels=num_channels,
filter_size=1,
num_filters=num_filters,
stride=1)
def forward(self, x):
x = self.dw_conv(x)
if self.use_se:
x = self.se(x)
x = self.pw_conv(x)
return x
class SEModule(nn.Layer):
def __init__(self, channel, reduction=4):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0)
self.relu = nn.ReLU()
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0)
self.hardsigmoid = nn.Hardsigmoid()
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.hardsigmoid(x)
x = paddle.multiply(x=identity, y=x)
return x
class PPLCNet(nn.Layer):
def __init__(self,
in_channels=3,
scale=1.0,
pretrained=False,
use_ssld=False):
super().__init__()
self.out_channels = [
int(NET_CONFIG["blocks3"][-1][2] * scale),
int(NET_CONFIG["blocks4"][-1][2] * scale),
int(NET_CONFIG["blocks5"][-1][2] * scale),
int(NET_CONFIG["blocks6"][-1][2] * scale)
]
self.scale = scale
self.conv1 = ConvBNLayer(
num_channels=in_channels,
filter_size=3,
num_filters=make_divisible(16 * scale),
stride=2)
self.blocks2 = nn.Sequential(* [
DepthwiseSeparable(
num_channels=make_divisible(in_c * scale),
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"])
])
self.blocks3 = nn.Sequential(* [
DepthwiseSeparable(
num_channels=make_divisible(in_c * scale),
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"])
])
self.blocks4 = nn.Sequential(* [
DepthwiseSeparable(
num_channels=make_divisible(in_c * scale),
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"])
])
self.blocks5 = nn.Sequential(* [
DepthwiseSeparable(
num_channels=make_divisible(in_c * scale),
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"])
])
self.blocks6 = nn.Sequential(* [
DepthwiseSeparable(
num_channels=make_divisible(in_c * scale),
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"])
])
if pretrained:
self._load_pretrained(
MODEL_URLS['PPLCNet_x{}'.format(scale)], use_ssld=use_ssld)
def forward(self, x):
outs = []
x = self.conv1(x)
x = self.blocks2(x)
x = self.blocks3(x)
outs.append(x)
x = self.blocks4(x)
outs.append(x)
x = self.blocks5(x)
outs.append(x)
x = self.blocks6(x)
outs.append(x)
return outs
def _load_pretrained(self, pretrained_url, use_ssld=False):
if use_ssld:
pretrained_url = pretrained_url.replace("_pretrained",
"_ssld_pretrained")
print(pretrained_url)
local_weight_path = get_path_from_url(
pretrained_url, os.path.expanduser("~/.paddleclas/weights"))
param_state_dict = paddle.load(local_weight_path)
self.set_dict(param_state_dict)
return
......@@ -42,14 +42,15 @@ def build_head(config):
#kie head
from .kie_sdmgr_head import SDMGRHead
from .table_att_head import TableAttentionHead
from .table_att_head import TableAttentionHead, SLAHead
from .table_master_head import TableMasterHead
support_dict = [
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead'
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'SLAHead'
]
#table head
......
......@@ -18,12 +18,26 @@ from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle import ParamAttr
import paddle.nn.functional as F
import numpy as np
from .rec_att_head import AttentionGRUCell
def get_para_bias_attr(l2_decay, k):
if l2_decay > 0:
regularizer = paddle.regularizer.L2Decay(l2_decay)
stdv = 1.0 / math.sqrt(k * 1.0)
initializer = nn.initializer.Uniform(-stdv, stdv)
else:
regularizer = None
initializer = None
weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
return [weight_attr, bias_attr]
class TableAttentionHead(nn.Layer):
def __init__(self,
in_channels,
......@@ -32,7 +46,7 @@ class TableAttentionHead(nn.Layer):
in_max_len=488,
max_text_length=800,
out_channels=30,
point_num=2,
loc_reg_num=4,
**kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
......@@ -56,7 +70,7 @@ class TableAttentionHead(nn.Layer):
else:
self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size,
point_num * 2)
loc_reg_num)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
......@@ -129,3 +143,121 @@ class TableAttentionHead(nn.Layer):
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class SLAHead(nn.Layer):
def __init__(self,
in_channels,
hidden_size,
out_channels=30,
max_text_length=500,
loc_reg_num=4,
fc_decay=0.0,
**kwargs):
"""
@param in_channels: input shape
@param hidden_size: hidden_size for RNN and Embedding
@param out_channels: num_classes to rec
@param max_text_length: max text pred
"""
super().__init__()
in_channels = in_channels[-1]
self.hidden_size = hidden_size
self.max_text_length = max_text_length
self.emb = self._char_to_onehot
self.num_embeddings = out_channels
# structure
self.structure_attention_cell = AttentionGRUCell(
in_channels, hidden_size, self.num_embeddings)
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=hidden_size)
weight_attr1_1, bias_attr1_1 = get_para_bias_attr(
l2_decay=fc_decay, k=hidden_size)
weight_attr1_2, bias_attr1_2 = get_para_bias_attr(
l2_decay=fc_decay, k=hidden_size)
self.structure_generator = nn.Sequential(
nn.Linear(
self.hidden_size,
self.hidden_size,
weight_attr=weight_attr1_2,
bias_attr=bias_attr1_2),
nn.Linear(
hidden_size,
out_channels,
weight_attr=weight_attr,
bias_attr=bias_attr))
# loc
weight_attr1, bias_attr1 = get_para_bias_attr(
l2_decay=fc_decay, k=self.hidden_size)
weight_attr2, bias_attr2 = get_para_bias_attr(
l2_decay=fc_decay, k=self.hidden_size)
self.loc_generator = nn.Sequential(
nn.Linear(
self.hidden_size,
self.hidden_size,
weight_attr=weight_attr1,
bias_attr=bias_attr1),
nn.Linear(
self.hidden_size,
loc_reg_num,
weight_attr=weight_attr2,
bias_attr=bias_attr2),
nn.Sigmoid())
def forward(self, inputs, targets=None):
fea = inputs[-1]
batch_size = fea.shape[0]
# reshape
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
hidden = paddle.zeros((batch_size, self.hidden_size))
structure_preds = []
loc_preds = []
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_text_length + 1):
hidden, structure_step, loc_step = self._decode(structure[:, i],
fea, hidden)
structure_preds.append(structure_step)
loc_preds.append(loc_step)
else:
pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
max_text_length = paddle.to_tensor(self.max_text_length)
# for export
loc_step, structure_step = None, None
for i in range(max_text_length + 1):
hidden, structure_step, loc_step = self._decode(pre_chars, fea,
hidden)
pre_chars = structure_step.argmax(axis=1, dtype="int32")
structure_preds.append(structure_step)
loc_preds.append(loc_step)
structure_preds = paddle.stack(structure_preds, axis=1)
loc_preds = paddle.stack(loc_preds, axis=1)
if not self.training:
structure_preds = F.softmax(structure_preds)
return {'structure_probs': structure_preds, 'loc_preds': loc_preds}
def _decode(self, pre_chars, features, hidden):
"""
Predict table label and coordinates for each step
@param pre_chars: Table label in previous step
@param features:
@param hidden: hidden status in previous step
@return:
"""
emb_feature = self.emb(pre_chars)
# output shape is b * self.hidden_size
(output, hidden), alpha = self.structure_attention_cell(
hidden, features, emb_feature)
# structure
structure_step = self.structure_generator(output)
# loc
loc_step = self.loc_generator(output)
return hidden, structure_step, loc_step
def _char_to_onehot(self, input_char):
input_ont_hot = F.one_hot(input_char, self.num_embeddings)
return input_ont_hot
......@@ -37,7 +37,7 @@ class TableMasterHead(nn.Layer):
d_ff=2048,
dropout=0,
max_text_length=500,
point_num=2,
loc_reg_num=4,
**kwargs):
super(TableMasterHead, self).__init__()
hidden_size = in_channels[-1]
......@@ -50,7 +50,7 @@ class TableMasterHead(nn.Layer):
self.cls_fc = nn.Linear(hidden_size, out_channels)
self.bbox_fc = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
nn.Linear(hidden_size, point_num * 2),
nn.Linear(hidden_size, loc_reg_num),
nn.Sigmoid())
self.norm = nn.LayerNorm(hidden_size)
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
......@@ -59,7 +59,7 @@ class TableMasterHead(nn.Layer):
self.SOS = out_channels - 3
self.PAD = out_channels - 1
self.out_channels = out_channels
self.point_num = point_num
self.loc_reg_num = loc_reg_num
self.max_text_length = max_text_length
def make_mask(self, tgt):
......@@ -105,7 +105,7 @@ class TableMasterHead(nn.Layer):
output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.out_channels])
bbox_output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.point_num * 2])
[input.shape[0], self.max_text_length + 1, self.loc_reg_num])
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
target_mask = self.make_mask(input)
......
......@@ -25,9 +25,10 @@ def build_neck(config):
from .fpn import FPN
from .fce_fpn import FCEFPN
from .pren_fpn import PRENFPN
from .csp_pan import CSPPAN
support_dict = [
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN'
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN'
]
module_name = config.pop('name')
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The code is based on:
# https://github.com/PaddlePaddle/PaddleDetection/blob/release%2F2.3/ppdet/modeling/necks/csp_pan.py
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
__all__ = ['CSPPAN']
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channel=96,
out_channel=96,
kernel_size=3,
stride=1,
groups=1,
act='leaky_relu'):
super(ConvBNLayer, self).__init__()
initializer = nn.initializer.KaimingUniform()
self.act = act
assert self.act in ['leaky_relu', "hard_swish"]
self.conv = nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel,
kernel_size=kernel_size,
groups=groups,
padding=(kernel_size - 1) // 2,
stride=stride,
weight_attr=ParamAttr(initializer=initializer),
bias_attr=False)
self.bn = nn.BatchNorm2D(out_channel)
def forward(self, x):
x = self.bn(self.conv(x))
if self.act == "leaky_relu":
x = F.leaky_relu(x)
elif self.act == "hard_swish":
x = F.hardswish(x)
return x
class DPModule(nn.Layer):
"""
Depth-wise and point-wise module.
Args:
in_channel (int): The input channels of this Module.
out_channel (int): The output channels of this Module.
kernel_size (int): The conv2d kernel size of this Module.
stride (int): The conv2d's stride of this Module.
act (str): The activation function of this Module,
Now support `leaky_relu` and `hard_swish`.
"""
def __init__(self,
in_channel=96,
out_channel=96,
kernel_size=3,
stride=1,
act='leaky_relu'):
super(DPModule, self).__init__()
initializer = nn.initializer.KaimingUniform()
self.act = act
self.dwconv = nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel,
kernel_size=kernel_size,
groups=out_channel,
padding=(kernel_size - 1) // 2,
stride=stride,
weight_attr=ParamAttr(initializer=initializer),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channel)
self.pwconv = nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=1,
groups=1,
padding=0,
weight_attr=ParamAttr(initializer=initializer),
bias_attr=False)
self.bn2 = nn.BatchNorm2D(out_channel)
def act_func(self, x):
if self.act == "leaky_relu":
x = F.leaky_relu(x)
elif self.act == "hard_swish":
x = F.hardswish(x)
return x
def forward(self, x):
x = self.act_func(self.bn1(self.dwconv(x)))
x = self.act_func(self.bn2(self.pwconv(x)))
return x
class DarknetBottleneck(nn.Layer):
"""The basic bottleneck block used in Darknet.
Each Block consists of two ConvModules and the input is added to the
final output. Each ConvModule is composed of Conv, BN, and act.
The first convLayer has filter size of 1x1 and the second one has the
filter size of 3x3.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
expansion (int): The kernel size of the convolution. Default: 0.5
add_identity (bool): Whether to add identity to the out.
Default: True
use_depthwise (bool): Whether to use depthwise separable convolution.
Default: False
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
expansion=0.5,
add_identity=True,
use_depthwise=False,
act="leaky_relu"):
super(DarknetBottleneck, self).__init__()
hidden_channels = int(out_channels * expansion)
conv_func = DPModule if use_depthwise else ConvBNLayer
self.conv1 = ConvBNLayer(
in_channel=in_channels,
out_channel=hidden_channels,
kernel_size=1,
act=act)
self.conv2 = conv_func(
in_channel=hidden_channels,
out_channel=out_channels,
kernel_size=kernel_size,
stride=1,
act=act)
self.add_identity = \
add_identity and in_channels == out_channels
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
if self.add_identity:
return out + identity
else:
return out
class CSPLayer(nn.Layer):
"""Cross Stage Partial Layer.
Args:
in_channels (int): The input channels of the CSP layer.
out_channels (int): The output channels of the CSP layer.
expand_ratio (float): Ratio to adjust the number of channels of the
hidden layer. Default: 0.5
num_blocks (int): Number of blocks. Default: 1
add_identity (bool): Whether to add identity in blocks.
Default: True
use_depthwise (bool): Whether to depthwise separable convolution in
blocks. Default: False
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
expand_ratio=0.5,
num_blocks=1,
add_identity=True,
use_depthwise=False,
act="leaky_relu"):
super().__init__()
mid_channels = int(out_channels * expand_ratio)
self.main_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
self.short_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
self.final_conv = ConvBNLayer(
2 * mid_channels, out_channels, 1, act=act)
self.blocks = nn.Sequential(* [
DarknetBottleneck(
mid_channels,
mid_channels,
kernel_size,
1.0,
add_identity,
use_depthwise,
act=act) for _ in range(num_blocks)
])
def forward(self, x):
x_short = self.short_conv(x)
x_main = self.main_conv(x)
x_main = self.blocks(x_main)
x_final = paddle.concat((x_main, x_short), axis=1)
return self.final_conv(x_final)
class Channel_T(nn.Layer):
def __init__(self,
in_channels=[116, 232, 464],
out_channels=96,
act="leaky_relu"):
super(Channel_T, self).__init__()
self.convs = nn.LayerList()
for i in range(len(in_channels)):
self.convs.append(
ConvBNLayer(
in_channels[i], out_channels, 1, act=act))
def forward(self, x):
outs = [self.convs[i](x[i]) for i in range(len(x))]
return outs
class CSPPAN(nn.Layer):
"""Path Aggregation Network with CSP module.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
kernel_size (int): The conv2d kernel size of this Module.
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1
use_depthwise (bool): Whether to depthwise separable convolution in
blocks. Default: True
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=5,
num_csp_blocks=1,
use_depthwise=True,
act='hard_swish'):
super(CSPPAN, self).__init__()
self.in_channels = in_channels
self.out_channels = [out_channels] * len(in_channels)
conv_func = DPModule if use_depthwise else ConvBNLayer
self.conv_t = Channel_T(in_channels, out_channels, act=act)
# build top-down blocks
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.top_down_blocks = nn.LayerList()
for idx in range(len(in_channels) - 1, 0, -1):
self.top_down_blocks.append(
CSPLayer(
out_channels * 2,
out_channels,
kernel_size=kernel_size,
num_blocks=num_csp_blocks,
add_identity=False,
use_depthwise=use_depthwise,
act=act))
# build bottom-up blocks
self.downsamples = nn.LayerList()
self.bottom_up_blocks = nn.LayerList()
for idx in range(len(in_channels) - 1):
self.downsamples.append(
conv_func(
out_channels,
out_channels,
kernel_size=kernel_size,
stride=2,
act=act))
self.bottom_up_blocks.append(
CSPLayer(
out_channels * 2,
out_channels,
kernel_size=kernel_size,
num_blocks=num_csp_blocks,
add_identity=False,
use_depthwise=use_depthwise,
act=act))
def forward(self, inputs):
"""
Args:
inputs (tuple[Tensor]): input features.
Returns:
tuple[Tensor]: CSPPAN features.
"""
assert len(inputs) == len(self.in_channels)
inputs = self.conv_t(inputs)
# top-down path
inner_outs = [inputs[-1]]
for idx in range(len(self.in_channels) - 1, 0, -1):
feat_heigh = inner_outs[0]
feat_low = inputs[idx - 1]
upsample_feat = F.upsample(
feat_heigh, size=feat_low.shape[2:4], mode="nearest")
inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
paddle.concat([upsample_feat, feat_low], 1))
inner_outs.insert(0, inner_out)
# bottom-up path
outs = [inner_outs[0]]
for idx in range(len(self.in_channels) - 1):
feat_low = outs[-1]
feat_height = inner_outs[idx + 1]
downsample_feat = self.downsamples[idx](feat_low)
out = self.bottom_up_blocks[idx](paddle.concat(
[downsample_feat, feat_height], 1))
outs.append(out)
return tuple(outs)
......@@ -23,7 +23,7 @@ class TableLabelDecode(AttnLabelDecode):
def __init__(self, character_dict_path, **kwargs):
super(TableLabelDecode, self).__init__(character_dict_path)
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
self.td_token = ['<td>', '<td', '<td></td>']
def __call__(self, preds, batch=None):
structure_probs = preds['structure_probs']
......@@ -114,10 +114,8 @@ class TableLabelDecode(AttnLabelDecode):
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
src_h = h / ratio_h
src_w = w / ratio_w
bbox[0::2] *= src_w
bbox[1::2] *= src_h
bbox[0::2] *= w
bbox[1::2] *= h
return bbox
......@@ -157,4 +155,7 @@ class TableMasterLabelDecode(TableLabelDecode):
bbox[1::2] *= h
bbox[0::2] /= ratio_w
bbox[1::2] /= ratio_h
x, y, w, h = bbox
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
bbox = np.array([x1, y1, x2, y2])
return bbox
......@@ -113,14 +113,10 @@ def draw_re_results(image,
return np.array(img_new)
def draw_rectangle(img_path, boxes, use_xywh=False):
def draw_rectangle(img_path, boxes):
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
if use_xywh:
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
else:
x1, y1, x2, y2 = box
x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_show
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy.special import softmax
def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
iou_threshold: intersection over union threshold.
top_k: keep top_k results. If k <= 0, keep all the results.
candidate_size: only consider the candidates with the highest scores.
Returns:
picked: a list of indexes of the kept boxes
"""
scores = box_scores[:, -1]
boxes = box_scores[:, :-1]
picked = []
indexes = np.argsort(scores)
indexes = indexes[-candidate_size:]
while len(indexes) > 0:
current = indexes[-1]
picked.append(current)
if 0 < top_k == len(picked) or len(indexes) == 1:
break
current_box = boxes[current, :]
indexes = indexes[:-1]
rest_boxes = boxes[indexes, :]
iou = iou_of(
rest_boxes,
np.expand_dims(
current_box, axis=0), )
indexes = indexes[iou <= iou_threshold]
return box_scores[picked, :]
def iou_of(boxes0, boxes1, eps=1e-5):
"""Return intersection-over-union (Jaccard index) of boxes.
Args:
boxes0 (N, 4): ground truth boxes.
boxes1 (N or 1, 4): predicted boxes.
eps: a small number to avoid 0 as denominator.
Returns:
iou (N): IoU values.
"""
overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
overlap_area = area_of(overlap_left_top, overlap_right_bottom)
area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
return overlap_area / (area0 + area1 - overlap_area + eps)
def area_of(left_top, right_bottom):
"""Compute the areas of rectangles given two corners.
Args:
left_top (N, 2): left top corner.
right_bottom (N, 2): right bottom corner.
Returns:
area (N): return the area.
"""
hw = np.clip(right_bottom - left_top, 0.0, None)
return hw[..., 0] * hw[..., 1]
class PicoDetPostProcess(object):
"""
Args:
input_shape (int): network input image size
ori_shape (int): ori image shape of before padding
scale_factor (float): scale factor of ori image
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self,
input_shape,
ori_shape,
scale_factor,
strides=[8, 16, 32, 64],
score_threshold=0.4,
nms_threshold=0.5,
nms_top_k=1000,
keep_top_k=100):
self.ori_shape = ori_shape
self.input_shape = input_shape
self.scale_factor = scale_factor
self.strides = strides
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
def warp_boxes(self, boxes, ori_shape):
"""Apply transform to boxes
"""
width, height = ori_shape[1], ori_shape[0]
n = len(boxes)
if n:
# warp points
xy = np.ones((n * 4, 3))
xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
n * 4, 2) # x1y1, x2y2, x1y2, x2y1
# xy = xy @ M.T # transform
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
xy = np.concatenate(
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip boxes
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
return xy.astype(np.float32)
else:
return boxes
def __call__(self, scores, raw_boxes):
batch_size = raw_boxes[0].shape[0]
reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
out_boxes_num = []
out_boxes_list = []
for batch_id in range(batch_size):
# generate centers
decode_boxes = []
select_scores = []
for stride, box_distribute, score in zip(self.strides, raw_boxes,
scores):
box_distribute = box_distribute[batch_id]
score = score[batch_id]
# centers
fm_h = self.input_shape[0] / stride
fm_w = self.input_shape[1] / stride
h_range = np.arange(fm_h)
w_range = np.arange(fm_w)
ww, hh = np.meshgrid(w_range, h_range)
ct_row = (hh.flatten() + 0.5) * stride
ct_col = (ww.flatten() + 0.5) * stride
center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)
# box distribution to distance
reg_range = np.arange(reg_max + 1)
box_distance = box_distribute.reshape((-1, reg_max + 1))
box_distance = softmax(box_distance, axis=1)
box_distance = box_distance * np.expand_dims(reg_range, axis=0)
box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
box_distance = box_distance * stride
# top K candidate
topk_idx = np.argsort(score.max(axis=1))[::-1]
topk_idx = topk_idx[:self.nms_top_k]
center = center[topk_idx]
score = score[topk_idx]
box_distance = box_distance[topk_idx]
# decode box
decode_box = center + [-1, -1, 1, 1] * box_distance
select_scores.append(score)
decode_boxes.append(decode_box)
# nms
bboxes = np.concatenate(decode_boxes, axis=0)
confidences = np.concatenate(select_scores, axis=0)
picked_box_probs = []
picked_labels = []
for class_index in range(0, confidences.shape[1]):
probs = confidences[:, class_index]
mask = probs > self.score_threshold
probs = probs[mask]
if probs.shape[0] == 0:
continue
subset_boxes = bboxes[mask, :]
box_probs = np.concatenate(
[subset_boxes, probs.reshape(-1, 1)], axis=1)
box_probs = hard_nms(
box_probs,
iou_threshold=self.nms_threshold,
top_k=self.keep_top_k, )
picked_box_probs.append(box_probs)
picked_labels.extend([class_index] * box_probs.shape[0])
if len(picked_box_probs) == 0:
out_boxes_list.append(np.empty((0, 4)))
out_boxes_num.append(0)
else:
picked_box_probs = np.concatenate(picked_box_probs)
# resize output boxes
picked_box_probs[:, :4] = self.warp_boxes(
picked_box_probs[:, :4], self.ori_shape[batch_id])
im_scale = np.concatenate([
self.scale_factor[batch_id][::-1],
self.scale_factor[batch_id][::-1]
])
picked_box_probs[:, :4] /= im_scale
# clas score box
out_boxes_list.append(
np.concatenate(
[
np.expand_dims(
np.array(picked_labels),
axis=-1), np.expand_dims(
picked_box_probs[:, 4], axis=-1),
picked_box_probs[:, :4]
],
axis=1))
out_boxes_num.append(len(picked_labels))
out_boxes_list = np.concatenate(out_boxes_list, axis=0)
out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
return out_boxes_list, out_boxes_num
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppstructure.utility import parse_args
from picodet_postprocess import PicoDetPostProcess
logger = get_logger()
class LayoutPredictor(object):
def __init__(self, args):
pre_process_list = [{
'Resize': {
'size': [800, 608]
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image']
}
}]
# postprocess_params = {
# 'name': 'LayoutPostProcess',
# "character_dict_path": args.layout_dict_path,
# }
self.preprocess_op = create_operators(pre_process_list)
# self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'layout', logger)
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
img = data[0]
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
preds, elapse = 0, 1
starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
# outputs = []
# for output_tensor in self.output_tensors:
# output = output_tensor.copy_to_cpu()
# outputs.append(output)
np_score_list, np_boxes_list = [], []
output_names = self.predictor.get_output_names()
num_outs = int(len(output_names) / 2)
for out_idx in range(num_outs):
np_score_list.append(
self.predictor.get_output_handle(output_names[out_idx])
.copy_to_cpu())
np_boxes_list.append(
self.predictor.get_output_handle(output_names[
out_idx + num_outs]).copy_to_cpu())
# result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
postprocessor = PicoDetPostProcess(
(800, 608), [[800., 608.]],
np.array([[1.010101, 0.99346405]]),
strides=[8, 16, 32, 64],
nms_threshold=0.5)
np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
# print(result)
im_bboxes_num = result['boxes_num'][0]
# print('im_bboxes_num:',im_bboxes_num)
bboxs = result['boxes'][0:0 + im_bboxes_num, :]
threshold = 0.5
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
preds = []
id2label = {1: 'text', 2: 'title', 3: 'list', 4: 'table', 5: 'figure'}
for dt in np_boxes:
clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
label = id2label[clsid + 1]
result_di = {'bbox': bbox, 'label': label}
preds.append(result_di)
# print('result_di',result_di)
# print('clsid, bbox, score:',clsid, bbox, score)
elapse = time.time() - starttime
return preds, elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
layout_predictor = LayoutPredictor(args)
count = 0
total_time = 0
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
layout_res, elapse = layout_predictor(img)
logger.info("result: {}".format(layout_res))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(parse_args())
......@@ -18,7 +18,7 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
......@@ -32,6 +32,7 @@ from attrdict import AttrDict
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.recovery.recovery_to_doc import convert_info_docx
......@@ -51,28 +52,14 @@ class StructureSystem(object):
"When args.layout is false, args.ocr is automatically set to false"
)
args.drop_score = 0
# init layout and ocr model
# init model
self.layout_predictor = None
self.text_system = None
self.table_system = None
if args.layout:
import layoutparser as lp
config_path = None
model_path = None
if os.path.isdir(args.layout_path_model):
model_path = args.layout_path_model
else:
config_path = args.layout_path_model
self.table_layout = lp.PaddleDetectionLayoutModel(
config_path=config_path,
model_path=model_path,
label_map=args.layout_label_map,
threshold=0.5,
enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu,
thread_num=args.cpu_threads)
self.layout_predictor = LayoutPredictor(args)
if args.ocr:
self.text_system = TextSystem(args)
else:
self.table_layout = None
if args.table:
if self.text_system is not None:
self.table_system = TableSystem(
......@@ -80,38 +67,59 @@ class StructureSystem(object):
self.text_system.text_recognizer)
else:
self.table_system = TableSystem(args)
else:
self.table_system = None
elif self.mode == 'vqa':
raise NotImplementedError
def __call__(self, img, return_ocr_result_in_table=False):
time_dict = {
'layout': 0,
'table': 0,
'table_match': 0,
'det': 0,
'rec': 0,
'vqa': 0,
'all': 0
}
start = time.time()
if self.mode == 'structure':
ori_im = img.copy()
if self.table_layout is not None:
layout_res = self.table_layout.detect(img[..., ::-1])
if self.layout_predictor is not None:
layout_res, elapse = self.layout_predictor(img)
time_dict['layout'] += elapse
else:
h, w = ori_im.shape[:2]
layout_res = [AttrDict(coordinates=[0, 0, w, h], type='Table')]
layout_res = [dict(bbox=None, label='table')]
res_list = []
for region in layout_res:
res = ''
x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
if region.type == 'Table':
if region['bbox'] is not None:
x1, y1, x2, y2 = region['bbox']
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
else:
x1, y1, x2, y2 = 0, 0, w, h
roi_img = ori_im
if region['label'] == 'table':
if self.table_system is not None:
res = self.table_system(roi_img,
return_ocr_result_in_table)
res, table_time_dict = self.table_system(
roi_img, return_ocr_result_in_table)
time_dict['table'] += table_time_dict['table']
time_dict['table_match'] += table_time_dict['match']
time_dict['det'] += table_time_dict['det']
time_dict['rec'] += table_time_dict['rec']
else:
if self.text_system is not None:
if args.recovery:
wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
wht_im[y1:y2, x1:x2, :] = roi_img
filter_boxes, filter_rec_res = self.text_system(wht_im)
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
wht_im)
else:
filter_boxes, filter_rec_res = self.text_system(roi_img)
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
roi_img)
time_dict['det'] += ocr_time_dict['det']
time_dict['rec'] += ocr_time_dict['rec']
# remove style char
style_token = [
'<strike>', '<strike>', '<sup>', '</sub>', '<b>',
......@@ -133,15 +141,17 @@ class StructureSystem(object):
'text_region': box.tolist()
})
res_list.append({
'type': region.type,
'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2],
'img': roi_img,
'res': res
})
return res_list
end = time.time()
time_dict['all'] = end - start
return res_list, time_dict
elif self.mode == 'vqa':
raise NotImplementedError
return None
return None, None
def save_structure_res(res, save_folder, img_name):
......@@ -156,12 +166,12 @@ def save_structure_res(res, save_folder, img_name):
roi_img = region.pop('img')
f.write('{}\n'.format(json.dumps(region)))
if region['type'] == 'Table' and len(region[
if region['type'] == 'table' and len(region[
'res']) > 0 and 'html' in region['res']:
excel_path = os.path.join(excel_save_folder,
'{}.xlsx'.format(region['bbox']))
to_excel(region['res']['html'], excel_path)
elif region['type'] == 'Figure':
elif region['type'] == 'figure':
img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox']))
cv2.imwrite(img_path, roi_img)
......@@ -188,7 +198,7 @@ def main(args):
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
res = structure_sys(img)
res, time_dict = structure_sys(img)
if structure_sys.mode == 'structure':
save_structure_res(res, save_folder, img_name)
......@@ -201,7 +211,7 @@ def main(args):
cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path))
if args.recovery:
convert_info_docx(img, res, save_folder, img_name)
convert_info_docx(img, res, save_folder, img_name)
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
......
......@@ -13,12 +13,14 @@
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import json
import pickle
import paddle
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem
......@@ -33,40 +35,74 @@ def parse_args():
parser.add_argument("--gt_path", type=str)
return parser.parse_args()
def main(gt_path, img_root, args):
teds = TEDS(n_jobs=16)
def load_txt(txt_path):
pred_html_dict = {}
if not os.path.exists(txt_path):
return pred_html_dict
with open(txt_path, encoding='utf-8') as f:
lines = f.readlines()
for line in lines:
line = line.strip().split('\t')
img_name, pred_html = line
pred_html_dict[img_name] = pred_html
return pred_html_dict
def load_result(path):
data = {}
if os.path.exists(path):
data = pickle.load(open(path, 'rb'))
return data
def save_result(path, data):
old_data = load_result(path)
old_data.update(data)
with open(path, 'wb') as f:
pickle.dump(old_data, f)
def main(gt_path, img_root, args):
os.makedirs(args.output, exist_ok=True)
# init TableSystem
text_sys = TableSystem(args)
jsons_gt = json.load(open(gt_path)) # gt
# load gt and preds html result
gt_html_dict = load_txt(gt_path)
ocr_result = load_result(os.path.join(args.output, 'ocr.pickle'))
structure_result = load_result(
os.path.join(args.output, 'structure.pickle'))
pred_htmls = []
gt_htmls = []
for img_name in tqdm(jsons_gt):
# read image
img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img)
pred_htmls.append(pred_html)
for img_name, gt_html in tqdm(gt_html_dict.items()):
img = cv2.imread(os.path.join(img_root, img_name))
# run ocr and save result
if img_name not in ocr_result:
dt_boxes, rec_res, _, _ = text_sys._ocr(img)
ocr_result[img_name] = [dt_boxes, rec_res]
save_result(os.path.join(args.output, 'ocr.pickle'), ocr_result)
# run structure and save result
if img_name not in structure_result:
structure_res, _ = text_sys._structure(img)
structure_result[img_name] = structure_res
save_result(
os.path.join(args.output, 'structure.pickle'), structure_result)
dt_boxes, rec_res = ocr_result[img_name]
structure_res = structure_result[img_name]
# match ocr and structure
pred_html = text_sys.match(structure_res, dt_boxes, rec_res)
gt_structures, gt_bboxes, gt_contents = jsons_gt[img_name]
gt_html, gt = get_gt_html(gt_structures, gt_contents)
pred_htmls.append(pred_html)
gt_htmls.append(gt_html)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
logger.info('teds:', sum(scores) / len(scores))
def get_gt_html(gt_structures, gt_contents):
end_html = []
td_index = 0
for tag in gt_structures:
if '</td>' in tag:
if gt_contents[td_index] != []:
end_html.extend(gt_contents[td_index])
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
# compute teds
teds = TEDS(n_jobs=16)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
logger.info('teds: {}'.format(sum(scores) / len(scores)))
if __name__ == '__main__':
args = parse_args()
main(args.gt_path,args.image_dir, args)
main(args.gt_path, args.image_dir, args)
import json
from ppstructure.table.table_master_match import deal_eb_token, deal_bb
def distance(box_1, box_2):
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4- x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4 - x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
def compute_iou(rec1, rec2):
"""
......@@ -18,23 +22,22 @@ def compute_iou(rec1, rec2):
# computing area of each rectangles
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
# computing the sum_area
sum_area = S_rec1 + S_rec2
# find the each edge of intersect rectangle
left_line = max(rec1[1], rec2[1])
right_line = min(rec1[3], rec2[3])
top_line = max(rec1[0], rec2[0])
bottom_line = min(rec1[2], rec2[2])
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0.0
else:
intersect = (right_line - left_line) * (bottom_line - top_line)
return (intersect / (sum_area - intersect))*1.0
return (intersect / (sum_area - intersect)) * 1.0
def matcher_merge(ocr_bboxes, pred_bboxes):
......@@ -45,15 +48,18 @@ def matcher_merge(ocr_bboxes, pred_bboxes):
distances = []
for j, pred_box in enumerate(pred_bboxes):
# compute l1 distence and IOU between two boxes
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
distances.append((distance(gt_box, pred_box),
1. - compute_iou(gt_box, pred_box)))
sorted_distances = distances.copy()
# select nearest cell
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched#, sum(ious) / len(ious)
return matched #, sum(ious) / len(ious)
def complex_num(pred_bboxes):
complex_nums = []
......@@ -67,6 +73,7 @@ def complex_num(pred_bboxes):
complex_nums.append(temp_ious[distances.index(min(distances))])
return sum(complex_nums) / len(complex_nums)
def get_rows(pred_bboxes):
pre_bbox = pred_bboxes[0]
res = []
......@@ -81,7 +88,9 @@ def get_rows(pred_bboxes):
for i in range(step):
pred_bboxes.pop(0)
return res, pred_bboxes
def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
ys_1 = []
ys_2 = []
for box in pred_bboxes:
......@@ -95,12 +104,14 @@ def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
box[3] = min_y_2
re_boxes.append(box)
return re_boxes
def matcher_refine_row(gt_bboxes, pred_bboxes):
before_refine_pred_bboxes = pred_bboxes.copy()
pred_bboxes = []
while(len(before_refine_pred_bboxes) != 0):
row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
while (len(before_refine_pred_bboxes) != 0):
row_bboxes, before_refine_pred_bboxes = get_rows(
before_refine_pred_bboxes)
print(row_bboxes)
pred_bboxes.extend(refine_rows(row_bboxes))
all_dis = []
......@@ -114,12 +125,11 @@ def matcher_refine_row(gt_bboxes, pred_bboxes):
#temp_ious.append(compute_iou(gt_box, pred_box))
#all_dis.append(min(distances))
#ious.append(temp_ious[distances.index(min(distances))])
if distances.index(min(distances)) not in matched.keys():
if distances.index(min(distances)) not in matched.keys():
matched[distances.index(min(distances))] = [i]
else:
matched[distances.index(min(distances))].append(i)
return matched#, sum(ious) / len(ious)
return matched #, sum(ious) / len(ious)
#先挑选出一行,再进行匹配
......@@ -128,29 +138,30 @@ def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
delete_gt_bboxes = gt_bboxes.copy()
match_bboxes_ready = []
matched = {}
while(len(delete_gt_bboxes) != 0):
while (len(delete_gt_bboxes) != 0):
row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
row_bboxes = sorted(row_bboxes, key=lambda key: key[0])
if len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
print(row_bboxes)
for i, gt_box in enumerate(row_bboxes):
#print(gt_box)
pred_distances = []
distances = []
distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print('index', index)
if index not in matched.keys():
if index not in matched.keys():
matched[index] = [gt_box_index]
else:
matched[index].append(gt_box_index)
gt_box_index += 1
return matched
def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
'''
gt_bboxes: 排序后
......@@ -161,7 +172,7 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
match_bboxes_ready = []
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
for i, gt_box in enumerate(gt_bboxes):
pred_distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
......@@ -184,9 +195,143 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
#print(gt_box, index)
#match_bboxes_ready.pop(distances.index(min(distances)))
print(gt_box, match_bboxes_ready[distances.index(min(distances))])
if index not in matched.keys():
if index not in matched.keys():
matched[index] = [i]
else:
matched[index].append(i)
pre_bbox = gt_box
return matched
class TableMatch:
def __init__(self, filter_ocr_result=False, use_master=False):
self.filter_ocr_result = filter_ocr_result
self.use_master = use_master
def __call__(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
if self.filter_ocr_result:
dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes, dt_boxes,
rec_res)
matched_index = self.match_result(dt_boxes, pred_bboxes)
if self.use_master:
pred_html, pred = self.get_pred_html_master(pred_structures,
matched_index, rec_res)
else:
pred_html, pred = self.get_pred_html(pred_structures, matched_index,
rec_res)
return pred_html
def match_result(self, dt_boxes, pred_bboxes):
matched = {}
for i, gt_box in enumerate(dt_boxes):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances = []
for j, pred_box in enumerate(pred_bboxes):
distances.append((distance(gt_box, pred_box),
1. - compute_iou(gt_box, pred_box)
)) # 获取两两cell之间的L1距离和 1- IOU
sorted_distances = distances.copy()
# 根据距离和IOU挑选最"近"的cell
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
end_html = []
td_index = 0
for tag in pred_structures:
if '</td>' in tag:
if '<td></td>' == tag:
end_html.extend('<td>')
if td_index in matched_index.keys():
b_with = False
if '<b>' in ocr_contents[matched_index[td_index][
0]] and len(matched_index[td_index]) > 1:
b_with = True
end_html.extend('<b>')
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
if content[0] == ' ':
content = content[1:]
if '<b>' in content:
content = content[3:]
if '</b>' in content:
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[
td_index]) - 1 and ' ' != content[-1]:
content += ' '
end_html.extend(content)
if b_with:
end_html.extend('</b>')
if '<td></td>' == tag:
end_html.append('</td>')
else:
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
def get_pred_html_master(self, pred_structures, matched_index,
ocr_contents):
end_html = []
td_index = 0
for token in pred_structures:
if '</td>' in token:
txt = ''
b_with = False
if td_index in matched_index.keys():
if '<b>' in ocr_contents[matched_index[td_index][
0]] and len(matched_index[td_index]) > 1:
b_with = True
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
if content[0] == ' ':
content = content[1:]
if '<b>' in content:
content = content[3:]
if '</b>' in content:
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[
td_index]) - 1 and ' ' != content[-1]:
content += ' '
txt += content
if b_with:
txt = '<b>{}</b>'.format(txt)
if '<td></td>' == token:
token = '<td>{}</td>'.format(txt)
else:
token = '{}</td>'.format(txt)
td_index += 1
token = deal_eb_token(token)
end_html.append(token)
html = ''.join(end_html)
html = deal_bb(html)
return html, end_html
def filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
y1 = pred_bboxes[:, 1::2].min()
new_dt_boxes = []
new_rec_res = []
for box, rec in zip(dt_boxes, rec_res):
if np.max(box[1::2]) < y1:
continue
new_dt_boxes.append(box)
new_rec_res.append(rec)
return new_dt_boxes, new_rec_res
......@@ -16,7 +16,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
......@@ -87,6 +87,7 @@ class TableStructurer(object):
utility.create_predictor(args, 'table', logger)
def __call__(self, img):
starttime = time.time()
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
......@@ -95,7 +96,6 @@ class TableStructurer(object):
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
......@@ -126,7 +126,6 @@ def main(args):
table_structurer = TableStructurer(args)
count = 0
total_time = 0
use_xywh = args.table_algorithm in ['TableMaster']
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
......@@ -146,7 +145,7 @@ def main(args):
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(image_file, bbox_list, use_xywh)
img = draw_rectangle(image_file, bbox_list)
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
......
......@@ -18,20 +18,23 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import logging
import numpy as np
import time
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
import tools.infer.utility as utility
from tools.infer.predict_system import sorted_boxes
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from ppstructure.table.matcher import distance, compute_iou
from ppstructure.table.matcher import TableMatch
from ppstructure.table.table_master_match import TableMasterMatcher
from ppstructure.utility import parse_args
import ppstructure.table.predict_structure as predict_strture
......@@ -55,11 +58,20 @@ def expand(pix, det_box, shape):
class TableSystem(object):
def __init__(self, args, text_detector=None, text_recognizer=None):
if not args.show_log:
logger.setLevel(logging.INFO)
self.text_detector = predict_det.TextDetector(
args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(
args) if text_recognizer is None else text_recognizer
self.table_structurer = predict_strture.TableStructurer(args)
if args.table_algorithm in ['TableMaster']:
self.match = TableMasterMatcher()
else:
self.match = TableMatch()
self.benchmark = args.benchmark
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
args, 'table', logger)
......@@ -85,16 +97,47 @@ class TableSystem(object):
def __call__(self, img, return_ocr_result_in_table=False):
result = dict()
ori_im = img.copy()
time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0}
start = time.time()
structure_res, elapse = self._structure(copy.deepcopy(img))
time_dict['table'] = elapse
dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
copy.deepcopy(img))
time_dict['det'] = det_elapse
time_dict['rec'] = rec_elapse
if return_ocr_result_in_table:
result['boxes'] = dt_boxes #[x.tolist() for x in dt_boxes]
result['rec_res'] = rec_res
tic = time.time()
pred_html = self.match(structure_res, dt_boxes, rec_res)
toc = time.time()
time_dict['match'] = toc - tic
# pred_html = self.match(1, 1, 1,img_name)
result['html'] = pred_html
if self.benchmark:
self.autolog.times.end(stamp=True)
end = time.time()
time_dict['all'] = end - start
if self.benchmark:
self.autolog.times.stamp()
return result, time_dict
def _structure(self, img):
if self.benchmark:
self.autolog.times.start()
structure_res, elapse = self.table_structurer(copy.deepcopy(img))
return structure_res, elapse
def _ocr(self, img):
if self.benchmark:
self.autolog.times.stamp()
dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
dt_boxes = sorted_boxes(dt_boxes)
if return_ocr_result_in_table:
result['boxes'] = [x.tolist() for x in dt_boxes]
r_boxes = []
for box in dt_boxes:
x_min = box[:, 0].min() - 1
......@@ -105,125 +148,20 @@ class TableSystem(object):
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
len(dt_boxes), det_elapse))
if dt_boxes is None:
return None, None
img_crop_list = []
for i in range(len(dt_boxes)):
det_box = dt_boxes[i]
x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
x0, y0, x1, y1 = expand(2, det_box, img.shape)
text_rect = img[int(y0):int(y1), int(x0):int(x1), :]
img_crop_list.append(text_rect)
rec_res, elapse = self.text_recognizer(img_crop_list)
rec_res, rec_elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
if self.benchmark:
self.autolog.times.stamp()
if return_ocr_result_in_table:
result['rec_res'] = rec_res
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
result['html'] = pred_html
if self.benchmark:
self.autolog.times.end(stamp=True)
return result
def rebuild_table(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes,dt_boxes, rec_res)
matched_index = self.match_result(dt_boxes, pred_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index,
rec_res)
return pred_html, pred
def filter_ocr_result(self, pred_bboxes,dt_boxes, rec_res):
y1 = pred_bboxes[:,1::2].min()
new_dt_boxes = []
new_rec_res = []
for box,rec in zip(dt_boxes, rec_res):
if np.max(box[1::2]) < y1:
continue
new_dt_boxes.append(box)
new_rec_res.append(rec)
return new_dt_boxes, new_rec_res
def match_result(self, dt_boxes, pred_bboxes):
matched = {}
for i, gt_box in enumerate(dt_boxes):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances = []
for j, pred_box in enumerate(pred_bboxes):
distances.append((distance(gt_box, pred_box),
1. - compute_iou(gt_box, pred_box)
)) # 获取两两cell之间的L1距离和 1- IOU
sorted_distances = distances.copy()
# 根据距离和IOU挑选最"近"的cell
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
end_html = []
td_index = 0
for tag in pred_structures:
if '</td>' in tag:
if td_index in matched_index.keys():
b_with = False
if '<b>' in ocr_contents[matched_index[td_index][
0]] and len(matched_index[td_index]) > 1:
b_with = True
end_html.extend('<b>')
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
if content[0] == ' ':
content = content[1:]
if '<b>' in content:
content = content[3:]
if '</b>' in content:
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[
td_index]) - 1 and ' ' != content[-1]:
content += ' '
end_html.extend(content)
if b_with:
end_html.extend('</b>')
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
len(rec_res), rec_elapse))
return dt_boxes, rec_res, det_elapse, rec_elapse
def to_excel(html_table, excel_path):
......@@ -249,7 +187,7 @@ def main(args):
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
pred_res = text_sys(img)
pred_res, _ = text_sys(img)
pred_html = pred_res['html']
logger.info(pred_html)
to_excel(pred_html, excel_path)
......
此差异已折叠。
......@@ -32,6 +32,7 @@ def init_args():
type=str,
default="../ppocr/utils/dict/table_structure_dict.txt")
# params for layout
parser.add_argument("--layout_model_dir", type=str)
parser.add_argument(
"--layout_path_model",
type=str,
......@@ -87,7 +88,7 @@ def draw_structure_result(image, result, font_path):
image = Image.fromarray(image)
boxes, txts, scores = [], [], []
for region in result:
if region['type'] == 'Table':
if region['type'] == 'table':
pass
else:
for text_result in region['res']:
......
......@@ -19,8 +19,6 @@ Global:
character_type: en
max_text_length: 800
infer_mode: False
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
......
......@@ -16,8 +16,6 @@ Global:
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
Optimizer:
......@@ -86,7 +84,7 @@ Train:
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
box_format: 'xywh'
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
......@@ -120,7 +118,7 @@ Eval:
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
box_format: 'xywh'
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
......
......@@ -65,9 +65,11 @@ class TextSystem(object):
self.crop_image_res_index += bbox_num
def __call__(self, img, cls=True):
time_dict = {'det': 0, 'rec': 0, 'csl': 0, 'all': 0}
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
time_dict['det'] = elapse
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
......@@ -83,10 +85,12 @@ class TextSystem(object):
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
time_dict['cls'] = elapse
logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict['rec'] = elapse
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
if self.args.save_crop_res:
......@@ -98,7 +102,9 @@ class TextSystem(object):
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
return filter_boxes, filter_rec_res
end = time.time()
time_dict['all'] = end - start
return filter_boxes, filter_rec_res, time_dict
def sorted_boxes(dt_boxes):
......@@ -133,9 +139,11 @@ def main(args):
os.makedirs(draw_img_save_dir, exist_ok=True)
save_results = []
logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
"if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320")
logger.info(
"In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
"if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
)
# warm up 10 times
if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
......@@ -155,7 +163,7 @@ def main(args):
logger.debug("error in loading image:{}".format(image_file))
continue
starttime = time.time()
dt_boxes, rec_res = text_sys(img)
dt_boxes, rec_res, time_dict = text_sys(img)
elapse = time.time() - starttime
total_time += elapse
......@@ -198,7 +206,10 @@ def main(args):
text_sys.text_detector.autolog.report()
text_sys.text_recognizer.autolog.report()
with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w', encoding='utf-8') as f:
with open(
os.path.join(draw_img_save_dir, "system_results.txt"),
'w',
encoding='utf-8') as f:
f.writelines(save_results)
......
......@@ -155,6 +155,8 @@ def create_predictor(args, mode, logger):
model_dir = args.table_model_dir
elif mode == 'ser':
model_dir = args.ser_model_dir
elif mode == 'layout':
model_dir = args.layout_model_dir
else:
model_dir = args.e2e_model_dir
......
......@@ -56,7 +56,6 @@ def main(config, device, logger, vdl_writer):
model = build_model(config['Architecture'])
algorithm = config['Architecture']['algorithm']
use_xywh = algorithm in ['TableMaster']
load_model(config, model)
......@@ -106,7 +105,7 @@ def main(config, device, logger, vdl_writer):
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(file, bbox_list, use_xywh)
img = draw_rectangle(file, bbox_list)
cv2.imwrite(
os.path.join(save_res_path, os.path.basename(file)), img)
logger.info("success!")
......
......@@ -154,6 +154,7 @@ def check_xpu(use_xpu):
except Exception as e:
pass
def to_float32(preds):
if isinstance(preds, dict):
for k in preds:
......@@ -173,6 +174,7 @@ def to_float32(preds):
preds = preds.astype(paddle.float32)
return preds
def train(config,
train_dataloader,
valid_dataloader,
......@@ -596,7 +598,7 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN'
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'SLANet'
]
if use_xpu:
......
......@@ -119,6 +119,10 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture'])
use_sync_bn = config["Global"].get("use_sync_bn", False)
if use_sync_bn:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
logger.info('convert_sync_batchnorm')
if config['Global']['distributed']:
model = paddle.DataParallel(model)
......@@ -157,7 +161,8 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True)
model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level='O2', master_weight=True)
else:
scaler = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册