未验证 提交 e93735a2 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #3083 from WenmuZhou/table1

[DO NOT MERGE]Table
include LICENSE.txt
include LICENSE
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
......
......@@ -355,3 +355,4 @@ im_show.save('result.jpg')
| det | 前向时使用启动检测 | TRUE |
| rec | 前向时是否启动识别 | TRUE |
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
| show_log | 是否打印det和rec等信息 | FALSE |
......@@ -362,3 +362,5 @@ im_show.save('result.jpg')
| det | Enable detction when `ppocr.ocr` func exec | TRUE |
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
| show_log | Whether to print log in det and rec
| FALSE |
\ No newline at end of file
......@@ -19,17 +19,16 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
import logging
import numpy as np
from pathlib import Path
import tarfile
import requests
from tqdm import tqdm
from tools.infer import predict_system
from ppocr.utils.logging import get_logger
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR']
......@@ -37,84 +36,84 @@ __all__ = ['PaddleOCR']
model_urls = {
'det': {
'ch':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'en':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
},
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
'french': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
},
'te': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
},
'ka': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
},
'latin': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
}
},
'cls':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
}
SUPPORT_DET_MODEL = ['DB']
......@@ -123,50 +122,6 @@ SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
def download_with_progressbar(url, save_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def parse_args(mMain=True):
import argparse
parser = init_args()
......@@ -194,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem):
args:
**kwargs: other params show in paddleocr --help
"""
postprocess_params = parse_args(mMain=False)
postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if not params.show_log:
logger.setLevel(logging.INFO)
self.use_angle_cls = params.use_angle_cls
lang = params.lang
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
......@@ -223,46 +180,45 @@ class PaddleOCR(predict_system.TextSystem):
lang = "devanagari"
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
model_urls['rec'].keys(), lang)
if lang == "ch":
det_lang = "ch"
else:
det_lang = "en"
use_inner_dict = False
if postprocess_params.rec_char_dict_path is None:
if params.rec_char_dict_path is None:
use_inner_dict = True
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path']
# init model dir
if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
'det', det_lang)
if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
'rec', lang)
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params)
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'det', det_lang),
model_urls['det'][det_lang])
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'rec', lang),
model_urls['rec'][lang]['url'])
params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
os.path.join(BASE_DIR, VERSION, 'cls'),
model_urls['cls'])
# download model
maybe_download(postprocess_params.det_model_dir,
model_urls['det'][det_lang])
maybe_download(postprocess_params.rec_model_dir,
model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.cls_model_dir, cls_url)
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0)
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
if params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0)
if use_inner_dict:
postprocess_params.rec_char_dict_path = str(
Path(__file__).parent / postprocess_params.rec_char_dict_path)
params.rec_char_dict_path = str(
Path(__file__).parent / params.rec_char_dict_path)
print(params)
# init det_model and rec_model
super().__init__(postprocess_params)
super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=True):
"""
......@@ -320,7 +276,7 @@ def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
if image_dir.startswith('http'):
if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
......
......@@ -29,6 +29,7 @@ from .label_ops import *
from .east_process import *
from .sast_process import *
from .pg_process import *
from .gen_table_mask import *
def transform(data, ops=None):
......
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class GenTableMask(object):
""" gen table mask """
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
return box_list, projection_map
def projection_cx(self, box_img):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
# 纵向腐蚀
if h < w:
kernel = np.ones((2, 1), np.uint8)
erode = cv2.erode(thresh1, kernel, iterations=1)
else:
erode = thresh1
# 水平膨胀
kernel = np.ones((1, 5), np.uint8)
erosion = cv2.dilate(erode, kernel, iterations=1)
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
split_bbox_list = []
if len(box_list) > 1:
for i, (h_start, h_end) in enumerate(box_list):
if i == 0:
h_start = 0
if i == len(box_list):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
h_end += 1
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
split_bbox_list.append([w_start, h_start, w_end, h_end])
else:
split_bbox_list.append([0, 0, w, h])
return split_bbox_list
def shrink_bbox(self, bbox):
left, top, right, bottom = bbox
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
left_new = left + sh_w
right_new = right - sh_w
top_new = top + sh_h
bottom_new = bottom - sh_h
if left_new >= right_new:
left_new = left
right_new = right
if top_new >= bottom_new:
top_new = top
bottom_new = bottom
return [left_new, top_new, right_new, bottom_new]
def __call__(self, data):
img = data['image']
cells = data['cells']
height, width = img.shape[0:2]
if self.mask_type == 1:
mask_img = np.zeros((height, width), dtype=np.float32)
else:
mask_img = np.zeros((height, width, 3), dtype=np.float32)
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
left, top, right, bottom = bbox
box_img = img[top:bottom, left:right, :].copy()
split_bbox_list = self.projection_cx(box_img)
for sno in range(len(split_bbox_list)):
split_bbox_list[sno][0] += left
split_bbox_list[sno][1] += top
split_bbox_list[sno][2] += left
split_bbox_list[sno][3] += top
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
class ResizeTableImage(object):
def __init__(self, max_len, **kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
def get_img_bbox(self, cells):
bbox_list = []
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = []
for bno in range(len(bbox_list)):
left, top, right, bottom = bbox_list[bno].copy()
left = int(left * ratio)
top = int(top * ratio)
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len
return data
class PaddingTableImage(object):
def __init__(self, **kwargs):
super(PaddingTableImage, self).__init__()
def __call__(self, data):
img = data['image']
max_len = data['max_len']
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
return data
\ No newline at end of file
......@@ -81,7 +81,7 @@ class NormalizeImage(object):
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
img.astype('float32') * self.scale - self.mean) / self.std
return data
......@@ -163,7 +163,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, _ = img.shape
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
......@@ -174,7 +174,7 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
else:
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
......@@ -182,6 +182,10 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h,w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
......
......@@ -24,7 +24,8 @@ __all__ = ['build_post_process']
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
TableLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
......@@ -33,7 +34,7 @@ def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode'
'DistillationCTCLabelDecode', 'TableLabelDecode'
]
config = copy.deepcopy(config)
......
......@@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = ""
self.character_str = []
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
self.character_str.append(line)
if use_space_char:
self.character_str += " "
self.character_str.append(" ")
dict_character = list(self.character_str)
else:
......@@ -319,3 +319,138 @@ class SRNLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class TableLabelDecode(object):
""" """
def __init__(self,
character_dict_path,
**kwargs):
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs,paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds,paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
</overline>
α

$
ω
ψ
χ
(
υ
σ
,
ρ
ε
0
4
8
b
<
Ψ
Ω
D
3
Π
H
</strike>
L
Φ
Χ
θ
P
κ
λ
μ
T
ξ
X
β
γ
δ
\
ζ
η
`
d
<strike>
h
f
l
Θ
p
t
</sub>
x
Β
Γ
Δ
|
ǂ
ɛ
j
̧
̌
«
#
</b>
'
Ι
+
/
·
7
;
?
C
÷
G
K
<sup>
O
S
С
W
Α
[
_
c
z
g
<i>
o
<sub>
s
w
φ
ʹ
{
»
̆
e
ˆ
τ
ι
Ø
ß
×
˃
˂
"
i
&
π
*
æ
.
ø
Q
6
:
>
a
B
F
J
̄
N
R
V
<overline>
Z
^
¤
¥
§
<underline>
¢
£
­
Λ
©
n
r
°
±
v
<b>
k
~
̇
@
ł
®
!
</sup>
%
)
-
1
5
9
=
А
A
Σ
E
I
M
m
̨
</i>
U
Y
]
̸
2
̂
̀
́
̊
̈
q
u
ı
y
</underline>
̃
}
ν
此差异已折叠。
......@@ -22,7 +22,7 @@ logger_initialized = {}
@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO):
def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tarfile
import requests
from tqdm import tqdm
from ppocr.utils.logging import get_logger
def download_with_progressbar(url, save_path):
logger = get_logger()
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
assert url.endswith('.tar'), 'Only supports tar compressed package'
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def is_link(s):
return s is not None and s.startswith('http')
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
url = default_url
if model_dir is None or is_link(model_dir):
if is_link(model_dir):
url = model_dir
file_name = url.split('/')[-1][:-4]
model_dir = default_model_dir
model_dir = os.path.join(model_dir, file_name)
return model_dir, url
include LICENSE
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include ppstructure *.py
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .paddlestructure import PaddleStructure, draw_result, to_excel
__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
# PaddleStructure
## 1. Introduction to pipeline
PaddleStructure is a toolkit for complex layout text OCR, the process is as follows
![pipeline](../doc/table/pipeline.png)
In PaddleStructure, the image will be analyzed by layoutparser first. In the layout analysis, the area in the image will be classified, and the OCR process will be carried out according to the category.
Currently layoutparser will output five categories:
1. Text
2. Title
3. Figure
4. List
5. Table
Types 1-4 follow the traditional OCR process, and 5 follow the Table OCR process.
## 2. LayoutParser
## 3. Table OCR
[doc](table/README.md)
## 4. PaddleStructure whl package introduction
### 4.1 Use
4.1.1 Use by code
```python
import cv2
from paddlestructure import PaddleStructure,draw_result
table_engine = PaddleStructure(
output='./output/table',
show_log=True)
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
for line in result:
print(line)
from PIL import Image
font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
4.1.2 Use by command line
```bash
paddlestructure --image_dir=../doc/table/1.png
```
### 参数说明
大部分参数和paddleocr whl包保持一致,见 [whl包文档](../doc/doc_ch/whl.md)
| 字段 | 说明 | 默认值 |
|------------------------|------------------------------------------------------|------------------|
| output | excel和识别结果保存的地址 | ./output/table |
| structure_max_len | structure模型预测时,图像的长边resize尺度 | 488 |
| structure_model_dir | structure inference 模型地址 | None |
| structure_char_type | structure 模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx |
# PaddleStructure
## 1. pipeline介绍
PaddleStructure 是一个用于复杂板式文字OCR的工具包,流程如下
![pipeline](../doc/table/pipeline.png)
在PaddleStructure中,图片会先经由layoutparser进行版面分析,在版面分析中,会对图片里的区域进行分类,根据根据类别进行对于的ocr流程。
目前layoutparser会输出五个类别:
1. Text
2. Title
3. Figure
4. List
5. Table
1-4类走传统的OCR流程,5走表格的OCR流程。
## 2. LayoutParser
## 3. Table OCR
[文档](table/README_ch.md)
## 4. PaddleStructure whl包介绍
### 4.1 使用
4.1.1 代码使用
```python
import cv2
from paddlestructure import PaddleStructure,draw_result
table_engine = PaddleStructure(
output='./output/table',
show_log=True)
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
for line in result:
print(line)
from PIL import Image
font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
4.1.2 命令行使用
```bash
paddlestructure --image_dir=../doc/table/1.png
```
### 参数说明
大部分参数和paddleocr whl包保持一致,见 [whl包文档](../doc/doc_ch/whl.md)
| 字段 | 说明 | 默认值 |
|------------------------|------------------------------------------------------|------------------|
| output | excel和识别结果保存的地址 | ./output/table |
| structure_max_len | structure模型预测时,图像的长边resize尺度 | 488 |
| structure_model_dir | structure inference 模型地址 | None |
| structure_char_type | structure 模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx |
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import cv2
import numpy as np
from pathlib import Path
from ppocr.utils.logging import get_logger
from test.predict_system import OCRSystem, save_res
from test.table.predict_table import to_excel
from test.utility import init_args, draw_result
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link
__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
VERSION = '2.1'
BASE_DIR = os.path.expanduser("~/.paddlestructure/")
model_urls = {
'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar',
'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar'
}
def parse_args(mMain=True):
import argparse
parser = init_args()
parser.add_help = mMain
for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']:
action.default = None
if mMain:
return parser.parse_args()
else:
inference_args_dict = {}
for action in parser._actions:
inference_args_dict[action.dest] = action.default
return argparse.Namespace(**inference_args_dict)
class PaddleStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if not params.show_log:
logger.setLevel(logging.INFO)
params.use_angle_cls = False
# init model dir
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'det'),
model_urls['det'])
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'rec'),
model_urls['rec'])
params.structure_model_dir, structure_url = confirm_model_dir_url(params.structure_model_dir,
os.path.join(BASE_DIR, VERSION, 'structure'),
model_urls['structure'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.structure_model_dir, structure_url)
if params.rec_char_dict_path is None:
params.rec_char_type = 'EN'
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
else:
params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
if params.structure_char_dict_path is None:
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
params.structure_char_dict_path = str(
Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
else:
params.structure_char_dict_path = str(
Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
print(params)
super().__init__(params)
def __call__(self, img):
if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
with open(image_file, 'rb') as f:
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img)
return res
def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
save_folder = args.output
if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir))
return
structure_engine = PaddleStructure(**(args.__dict__))
for img_path in image_file_list:
img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = structure_engine(img_path)
for item in result:
logger.info(item['res'])
save_res(result, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time
import logging
import layoutparser as lp
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from test.table.predict_table import TableSystem, to_excel
from test.utility import parse_args, draw_result
logger = get_logger()
class OCRSystem(object):
def __init__(self, args):
args.det_limit_type = 'resize_long'
args.drop_score = 0
if not args.show_log:
logger.setLevel(logging.INFO)
self.text_system = TextSystem(args)
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
def __call__(self, img):
ori_im = img.copy()
layout_res = self.table_layout.detect(img[..., ::-1])
res_list = []
for region in layout_res:
x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
if region.type == 'Table':
res = self.table_system(roi_img)
else:
filter_boxes, filter_rec_res = self.text_system(roi_img)
filter_boxes = [x + [x1, y1] for x in filter_boxes]
filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
res = (filter_boxes, filter_rec_res)
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
return res_list
def save_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['type'] == 'Table':
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
elif region['type'] == 'Figure':
pass
else:
with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
for box, rec_res in zip(region['res'][0], region['res'][1]):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num]
save_folder = args.output
os.makedirs(save_folder, exist_ok=True)
structure_sys = OCRSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
img_name = os.path.basename(image_file).split('.')[0]
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
res = structure_sys(img)
save_res(res, save_folder, img_name)
draw_img = draw_result(img, res, args.vis_font_path)
cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
if __name__ == "__main__":
args = parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
cmd = [sys.executable, "-u"] + sys.argv + [
"--process_id={}".format(process_id),
"--use_mp={}".format(False)
]
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
p.wait()
else:
main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from setuptools import setup
from io import open
import shutil
with open('../requirements.txt', encoding="utf-8-sig") as f:
requirements = f.readlines()
requirements.append('tqdm')
requirements.append('layoutparser')
requirements.append('iopath')
def readme():
with open('api_ch.md', encoding="utf-8-sig") as f:
README = f.read()
return README
shutil.copytree('/table', './test/table')
shutil.copyfile('/predict_system.py', './test/predict_system.py')
shutil.copyfile('/utility.py', './test/utility.py')
shutil.copytree('../ppocr', './ppocr')
shutil.copytree('../tools', './tools')
shutil.copyfile('../LICENSE', './LICENSE')
setup(
name='paddlestructure',
packages=['paddlestructure'],
package_dir={'paddlestructure': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
version='1.0',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
long_description=readme(),
long_description_content_type='text/markdown',
url='https://github.com/PaddlePaddle/PaddleOCR',
download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
keywords=[
'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
],
classifiers=[
'Intended Audience :: Developers', 'Operating System :: OS Independent',
'Natural Language :: Chinese (Simplified)',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.2',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
], )
shutil.rmtree('ppocr')
shutil.rmtree('tools')
shutil.rmtree('test')
os.remove('LICENSE')
# Table structure and content prediction
## 1. pipeline
The ocr of the table mainly contains three models
1. Single line text detection-DB
2. Single line text recognition-CRNN
3. Table structure and cell coordinate prediction-RARE
The table ocr flow chart is as follows
![tableocr_pipeline](../../doc/table/tableocr_pipeline.png)
1. The coordinates of single-line text is detected by DB model, and then sends it to the recognition model to get the recognition result.
2. The table structure and cell coordinates is predicted by RARE model.
3. The recognition result of the cell is combined by the coordinates, recognition result of the single line and the coordinates of the cell.
4. The cell recognition result and the table structure together construct the html string of the table.
## 2. How to use
### 2.1 Train
TBD
### 2.2 Eval
First cd to the PaddleOCR/ppstructure directory
The table uses TEDS (Tree-Edit-Distance-based Similarity) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. Examples of gt are as follows:
```json
{"PMC4289340_004_00.png": [["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"], [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], [["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]]}
```
In gt json, the key is the image name, the value is the corresponding gt, and gt is a list composed of four items, and each item is
1. HTML string list of table structure
2. The coordinates of each cell (not including the empty text in the cell)
3. The text information in each cell (not including the empty text in the cell)
4. The text information in each cell (including the empty text in the cell)
Use the following command to evaluate. After the evaluation is completed, the teds indicator will be output.
```python
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
### 2.3 Inference
First cd to the PaddleOCR/ppstructure directory
```python
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
After running, the excel sheet of each picture will be saved in the directory specified by the table_output field
\ No newline at end of file
# 表格结构和内容预测
## 1. pipeline
表格的ocr主要包含三个模型
1. 单行文本检测-DB
2. 单行文本识别-CRNN
3. 表格结构和cell坐标预测-RARE
具体流程图如下
![tableocr_pipeline](../../doc/table/tableocr_pipeline.png)
1. 图片由单行文字检测检测模型到单行文字的坐标,然后送入识别模型拿到识别结果。
2. 图片由表格结构和cell坐标预测模型拿到表格的结构信息和单元格的坐标信息。
3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。
4. 单元格的识别结果和表格结构一起构造表格的html字符串。
## 2. 使用
### 2.1 训练
TBD
### 2.2 评估
先cd到PaddleOCR/ppstructure目录下
表格使用 TEDS(Tree-Edit-Distance-based Similarity) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
```json
{"PMC4289340_004_00.png": [["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"], [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], [["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]]}
```
json 中,key为图片名,value为对于的gt,gt是一个由四个item组成的list,每个item分别为
1. 表格结构的html字符串list
2. 每个cell的坐标 (不包括cell里文字为空的)
3. 每个cell里的文字信息 (不包括cell里文字为空的)
4. 每个cell里的文字信息 (包括cell里文字为空的)
准备完成后使用如下命令进行评估,评估完成后会输出teds指标。
```python
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
### 2.3 预测
先cd到PaddleOCR/ppstructure目录下
```python
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后,每张图片的excel表格会保存到table_output字段指定的目录下
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import json
from tqdm import tqdm
from test.table.table_metric import TEDS
from test.table.predict_table import TableSystem
from test.utility import init_args
from ppocr.utils.logging import get_logger
logger = get_logger()
def parse_args():
parser = init_args()
parser.add_argument("--gt_path", type=str)
return parser.parse_args()
def main(gt_path, img_root, args):
teds = TEDS(n_jobs=16)
text_sys = TableSystem(args)
jsons_gt = json.load(open(gt_path)) # gt
pred_htmls = []
gt_htmls = []
for img_name in tqdm(jsons_gt):
# read image
img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img)
pred_htmls.append(pred_html)
gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
gt_html, gt = get_gt_html(gt_structures, contents_with_block)
gt_htmls.append(gt_html)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
logger.info('teds:', sum(scores) / len(scores))
def get_gt_html(gt_structures, contents_with_block):
end_html = []
td_index = 0
for tag in gt_structures:
if '</td>' in tag:
if contents_with_block[td_index] != []:
end_html.extend(contents_with_block[td_index])
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
if __name__ == '__main__':
args = parse_args()
main(args.gt_path,args.image_dir, args)
import json
def distance(box_1, box_2):
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4- x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
def compute_iou(rec1, rec2):
"""
computing IoU
:param rec1: (y0, x0, y1, x1), which reflects
(top, left, bottom, right)
:param rec2: (y0, x0, y1, x1)
:return: scala value of IoU
"""
# computing area of each rectangles
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
# computing the sum_area
sum_area = S_rec1 + S_rec2
# find the each edge of intersect rectangle
left_line = max(rec1[1], rec2[1])
right_line = min(rec1[3], rec2[3])
top_line = max(rec1[0], rec2[0])
bottom_line = min(rec1[2], rec2[2])
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0.0
else:
intersect = (right_line - left_line) * (bottom_line - top_line)
return (intersect / (sum_area - intersect))*1.0
def matcher_merge(ocr_bboxes, pred_bboxes):
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(ocr_bboxes):
distances = []
for j, pred_box in enumerate(pred_bboxes):
# compute l1 distence and IOU between two boxes
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
sorted_distances = distances.copy()
# select nearest cell
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched#, sum(ious) / len(ious)
def complex_num(pred_bboxes):
complex_nums = []
for bbox in pred_bboxes:
distances = []
temp_ious = []
for pred_bbox in pred_bboxes:
if bbox != pred_bbox:
distances.append(distance(bbox, pred_bbox))
temp_ious.append(compute_iou(bbox, pred_bbox))
complex_nums.append(temp_ious[distances.index(min(distances))])
return sum(complex_nums) / len(complex_nums)
def get_rows(pred_bboxes):
pre_bbox = pred_bboxes[0]
res = []
step = 0
for i in range(len(pred_bboxes)):
bbox = pred_bboxes[i]
if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
break
else:
res.append(bbox)
step += 1
for i in range(step):
pred_bboxes.pop(0)
return res, pred_bboxes
def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
ys_1 = []
ys_2 = []
for box in pred_bboxes:
ys_1.append(box[1])
ys_2.append(box[3])
min_y_1 = sum(ys_1) / len(ys_1)
min_y_2 = sum(ys_2) / len(ys_2)
re_boxes = []
for box in pred_bboxes:
box[1] = min_y_1
box[3] = min_y_2
re_boxes.append(box)
return re_boxes
def matcher_refine_row(gt_bboxes, pred_bboxes):
before_refine_pred_bboxes = pred_bboxes.copy()
pred_bboxes = []
while(len(before_refine_pred_bboxes) != 0):
row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
print(row_bboxes)
pred_bboxes.extend(refine_rows(row_bboxes))
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(gt_bboxes):
distances = []
#temp_ious = []
for j, pred_box in enumerate(pred_bboxes):
distances.append(distance(gt_box, pred_box))
#temp_ious.append(compute_iou(gt_box, pred_box))
#all_dis.append(min(distances))
#ious.append(temp_ious[distances.index(min(distances))])
if distances.index(min(distances)) not in matched.keys():
matched[distances.index(min(distances))] = [i]
else:
matched[distances.index(min(distances))].append(i)
return matched#, sum(ious) / len(ious)
#先挑选出一行,再进行匹配
def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
gt_box_index = 0
delete_gt_bboxes = gt_bboxes.copy()
match_bboxes_ready = []
matched = {}
while(len(delete_gt_bboxes) != 0):
row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
if len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
print(row_bboxes)
for i, gt_box in enumerate(row_bboxes):
#print(gt_box)
pred_distances = []
distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print('index', index)
if index not in matched.keys():
matched[index] = [gt_box_index]
else:
matched[index].append(gt_box_index)
gt_box_index += 1
return matched
def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
'''
gt_bboxes: 排序后
pred_bboxes:
'''
pre_bbox = gt_bboxes[0]
matched = {}
match_bboxes_ready = []
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
for i, gt_box in enumerate(gt_bboxes):
pred_distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
distances = []
gap_pre = gt_box[1] - pre_bbox[1]
gap_pre_1 = gt_box[0] - pre_bbox[2]
#print(gap_pre, len(pred_bboxes_rows))
if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0):
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(pred_bboxes_rows) == 1:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
break
#print(match_bboxes_ready)
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print(gt_box, index)
#match_bboxes_ready.pop(distances.index(min(distances)))
print(gt_box, match_bboxes_ready[distances.index(min(distances))])
if index not in matched.keys():
matched[index] = [i]
else:
matched[index].append(i)
pre_bbox = gt_box
return matched
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import math
import time
import traceback
import paddle
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from test.utility import parse_args
logger = get_logger()
class TableStructurer(object):
def __init__(self, args):
pre_process_list = [{
'ResizeTableImage': {
'max_len': args.structure_max_len
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'PaddingTableImage': None
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image']
}
}]
postprocess_params = {
'name': 'TableLabelDecode',
"character_type": args.structure_char_type,
"character_dict_path": args.structure_char_dict_path,
}
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'structure', logger)
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
img = data[0]
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {}
preds['structure_probs'] = outputs[1]
preds['loc_preds'] = outputs[0]
post_result = self.postprocess_op(preds)
structure_str_list = post_result['structure_str_list']
res_loc = post_result['res_loc']
imgh, imgw = ori_im.shape[0:2]
res_loc_final = []
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
res_loc_final.append([left, top, right, bottom])
structure_str_list = structure_str_list[0][:-1]
structure_str_list = ['<html>', '<body>', '<table>'] + structure_str_list + ['</table>', '</body>', '</html>']
elapse = time.time() - starttime
return (structure_str_list, res_loc_final), elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
table_structurer = TableStructurer(args)
count = 0
total_time = 0
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
structure_res, elapse = table_structurer(img)
logger.info("result: {}".format(structure_res))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(parse_args())
此差异已折叠。
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['TEDS']
from .table_metric import TEDS
\ No newline at end of file
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
"""
A parallel version of the map function with a progress bar.
Args:
array (array-like): An array to iterate over.
function (function): A python function to apply to the elements of array
n_jobs (int, default=16): The number of cores to use
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
keyword arguments to function
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
Useful for catching bugs
Returns:
[function(array[0]), function(array[1]), ...]
"""
# We run the first few iterations serially to catch bugs
if front_num > 0:
front = [function(**a) if use_kwargs else function(a)
for a in array[:front_num]]
else:
front = []
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if n_jobs == 1:
return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
# Assemble the workers
with ProcessPoolExecutor(max_workers=n_jobs) as pool:
# Pass the elements of array into function
if use_kwargs:
futures = [pool.submit(function, **a) for a in array[front_num:]]
else:
futures = [pool.submit(function, a) for a in array[front_num:]]
kwargs = {
'total': len(futures),
'unit': 'it',
'unit_scale': True,
'leave': True
}
# Print out the progress as tasks complete
for f in tqdm(as_completed(futures), **kwargs):
pass
out = []
# Get the results from the futures.
for i, future in tqdm(enumerate(futures)):
try:
out.append(future.result())
except Exception as e:
out.append(e)
return front + out
此差异已折叠。
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -43,7 +43,7 @@ class TextDetector(object):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type
'limit_type': args.det_limit_type,
}
}, {
'NormalizeImage': {
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册