diff --git a/MANIFEST.in b/MANIFEST.in index e16f157d6e9dd249d6c6a14ae54313759a6752c4..cd1c9636d4d23cc4d0f745403ec8ca407d1cc1a8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,7 @@ -include LICENSE.txt +include LICENSE include README.md -recursive-include ppocr/utils *.txt utility.py logging.py +recursive-include ppocr/utils *.txt utility.py logging.py network.py recursive-include ppocr/data/ *.py recursive-include ppocr/postprocess *.py recursive-include tools/infer *.py diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java index 1c83e2184fe55aedd5022da839ab294b6bbe475c..b4ea34e2a38f91f3ecb1001c6bff3b71496b8f91 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java @@ -465,8 +465,12 @@ public class MainActivity extends AppCompatActivity { } public void btn_load_model_click(View view) { - tvStatus.setText("STATUS: load model ......"); - loadModel(); + if (predictor.isLoaded()){ + tvStatus.setText("STATUS: model has been loaded"); + }else{ + tvStatus.setText("STATUS: load model ......"); + loadModel(); + } } public void btn_run_model_click(View view) { diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java index 1c294995c25b7eb3fa6ded17f41f193bddfc3886..b474d8886a10746b8ac181085c62481dfe7a4229 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java @@ -194,26 +194,25 @@ public class Predictor { "supported!"); return false; } - int[] channelStride = new int[]{width * height, width * height * 2}; - int p = scaleImage.getPixel(scaleImage.getWidth() - 1, scaleImage.getHeight() - 1); - for (int y = 0; y < height; y++) { - for (int x = 0; x < width; x++) { - int color = scaleImage.getPixel(x, y); - float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f, - (float) blue(color) / 255.0f}; - inputData[y * width + x] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0]; - inputData[y * width + x + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1]; - inputData[y * width + x + channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2]; - } + int[] channelStride = new int[]{width * height, width * height * 2}; + int[] pixels=new int[width*height]; + scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight()); + for (int i = 0; i < pixels.length; i++) { + int color = pixels[i]; + float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f, + (float) blue(color) / 255.0f}; + inputData[i] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0]; + inputData[i + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1]; + inputData[i+ channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2]; } } else if (channels == 1) { - for (int y = 0; y < height; y++) { - for (int x = 0; x < width; x++) { - int color = inputImage.getPixel(x, y); - float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f; - inputData[y * width + x] = (gray - inputMean[0]) / inputStd[0]; - } + int[] pixels=new int[width*height]; + scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight()); + for (int i = 0; i < pixels.length; i++) { + int color = pixels[i]; + float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f; + inputData[i] = (gray - inputMean[0]) / inputStd[0]; } } else { Log.i(TAG, "Unsupported channel size " + Integer.toString(channels) + ", only channel 1 and 3 is " + diff --git a/doc/doc_ch/config.md b/doc/doc_ch/config.md index 3f8f7ff1674d63e721d7ad2ced31bf771b0183eb..74cd238134d1999a6fbd96d0ad053d0304231a0b 100644 --- a/doc/doc_ch/config.md +++ b/doc/doc_ch/config.md @@ -111,9 +111,9 @@ | 字段 | 用途 | 默认值 | 备注 | | :---------------------: | :---------------------: | :--------------: | :--------------------: | | **dataset** | 每次迭代返回一个样本 | - | - | -| name | dataset类名 | SimpleDataSet | 目前支持`SimpleDataSet`和`LMDBDateSet` | +| name | dataset类名 | SimpleDataSet | 目前支持`SimpleDataSet`和`LMDBDataSet` | | data_dir | 数据集图片存放路径 | ./train_data | \ | -| label_file_list | 数据标签路径 | ["./train_data/train_list.txt"] | dataset为LMDBDateSet时不需要此参数 | +| label_file_list | 数据标签路径 | ["./train_data/train_list.txt"] | dataset为LMDBDataSet时不需要此参数 | | ratio_list | 数据集的比例 | [1.0] | 若label_file_list中有两个train_list,且ratio_list为[0.4,0.6],则从train_list1中采样40%,从train_list2中采样60%组合整个dataset | | transforms | 对图片和标签进行变换的方法列表 | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | 见[ppocr/data/imaug](../../ppocr/data/imaug) | | **loader** | dataloader相关 | - | | diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index 8a7c341cf24738b8af8c974a6da41bcb1b51ce48..0f860065bef9eff8f90c18f120e43dcf0c2a47aa 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -243,7 +243,7 @@ Optimizer: Train: dataset: - # 数据集格式,支持LMDBDateSet以及SimpleDataSet + # 数据集格式,支持LMDBDataSet以及SimpleDataSet name: SimpleDataSet # 数据集路径 data_dir: ./train_data/ @@ -263,7 +263,7 @@ Train: Eval: dataset: - # 数据集格式,支持LMDBDateSet以及SimpleDataSet + # 数据集格式,支持LMDBDataSet以及SimpleDataSet name: SimpleDataSet # 数据集路径 data_dir: ./train_data @@ -393,7 +393,7 @@ Global: Train: dataset: - # 数据集格式,支持LMDBDateSet以及SimpleDataSet + # 数据集格式,支持LMDBDataSet以及SimpleDataSet name: SimpleDataSet # 数据集路径 data_dir: ./train_data/ @@ -403,7 +403,7 @@ Train: Eval: dataset: - # 数据集格式,支持LMDBDateSet以及SimpleDataSet + # 数据集格式,支持LMDBDataSet以及SimpleDataSet name: SimpleDataSet # 数据集路径 data_dir: ./train_data diff --git a/doc/doc_en/config_en.md b/doc/doc_en/config_en.md index 28ebb6e830369447395c661cbcc76aaf067a91d9..5e5847c4b298553b2d376b90196b61b7e0286efe 100644 --- a/doc/doc_en/config_en.md +++ b/doc/doc_en/config_en.md @@ -110,9 +110,9 @@ In ppocr, the network is divided into four stages: Transform, Backbone, Neck and | Parameter | Use | Defaults | Note | | :---------------------: | :---------------------: | :--------------: | :--------------------: | | **dataset** | Return one sample per iteration | - | - | -| name | dataset class name | SimpleDataSet | Currently support`SimpleDataSet`,`LMDBDateSet` | +| name | dataset class name | SimpleDataSet | Currently support`SimpleDataSet`,`LMDBDataSet` | | data_dir | Image folder path | ./train_data | \ | -| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDateSet | +| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDataSet | | ratio_list | Ratio of data set | [1.0] | If there are two train_lists in label_file_list and ratio_list is [0.4,0.6], 40% will be sampled from train_list1, and 60% will be sampled from train_list2 to combine the entire dataset | | transforms | List of methods to transform images and labels | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | see[ppocr/data/imaug](../../ppocr/data/imaug) | | **loader** | dataloader related | - | | diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md index 0b3db6a235bdbfeb930d6cf3f7d086829fd32c43..e23166e0caef4f6a246502fa12101f86d61e4eac 100644 --- a/doc/doc_en/recognition_en.md +++ b/doc/doc_en/recognition_en.md @@ -237,7 +237,7 @@ Optimizer: Train: dataset: - # Type of dataset,we support LMDBDateSet and SimpleDataSet + # Type of dataset,we support LMDBDataSet and SimpleDataSet name: SimpleDataSet # Path of dataset data_dir: ./train_data/ @@ -257,7 +257,7 @@ Train: Eval: dataset: - # Type of dataset,we support LMDBDateSet and SimpleDataSet + # Type of dataset,we support LMDBDataSet and SimpleDataSet name: SimpleDataSet # Path of dataset data_dir: ./train_data @@ -394,7 +394,7 @@ Global: Train: dataset: - # Type of dataset,we support LMDBDateSet and SimpleDataSet + # Type of dataset,we support LMDBDataSet and SimpleDataSet name: SimpleDataSet # Path of dataset data_dir: ./train_data/ @@ -404,7 +404,7 @@ Train: Eval: dataset: - # Type of dataset,we support LMDBDateSet and SimpleDataSet + # Type of dataset,we support LMDBDataSet and SimpleDataSet name: SimpleDataSet # Path of dataset data_dir: ./train_data diff --git a/doc/joinus.PNG b/doc/joinus.PNG index 6e299a2ebe0eb52aa799ba9fa924bd685cd248de..4a274e631d8516789fca47e2db66bc0ac2d0f223 100644 Binary files a/doc/joinus.PNG and b/doc/joinus.PNG differ diff --git a/doc/table/1.png b/doc/table/1.png new file mode 100644 index 0000000000000000000000000000000000000000..47df618ab1bef431a5dd94418c01be16b09d31aa Binary files /dev/null and b/doc/table/1.png differ diff --git a/paddleocr.py b/paddleocr.py index 1e4d94ff4e72da951e1ffb92edb50715482581ae..48c8c9c6523dc3f813189477e641f0e51b740885 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -19,17 +19,16 @@ __dir__ = os.path.dirname(__file__) sys.path.append(os.path.join(__dir__, '')) import cv2 +import logging import numpy as np from pathlib import Path -import tarfile -import requests -from tqdm import tqdm from tools.infer import predict_system from ppocr.utils.logging import get_logger logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list +from ppocr.utils.network import maybe_download, download_with_progressbar from tools.infer.utility import draw_ocr, init_args, str2bool __all__ = ['PaddleOCR'] @@ -37,84 +36,84 @@ __all__ = ['PaddleOCR'] model_urls = { 'det': { 'ch': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', 'en': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar' + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar' }, 'rec': { 'ch': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/ppocr_keys_v1.txt' }, 'en': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/en_dict.txt' }, 'french': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/french_dict.txt' }, 'german': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/german_dict.txt' }, 'korean': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/korean_dict.txt' }, 'japan': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/japan_dict.txt' }, 'chinese_cht': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt' }, 'ta': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/ta_dict.txt' }, 'te': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/te_dict.txt' }, 'ka': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/ka_dict.txt' }, 'latin': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/latin_dict.txt' }, 'arabic': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/arabic_dict.txt' }, 'cyrillic': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt' }, 'devanagari': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/devanagari_dict.txt' } }, 'cls': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar' + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar' } SUPPORT_DET_MODEL = ['DB'] @@ -123,50 +122,6 @@ SUPPORT_REC_MODEL = ['CRNN'] BASE_DIR = os.path.expanduser("~/.paddleocr/") -def download_with_progressbar(url, save_path): - response = requests.get(url, stream=True) - total_size_in_bytes = int(response.headers.get('content-length', 0)) - block_size = 1024 # 1 Kibibyte - progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) - with open(save_path, 'wb') as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - progress_bar.close() - if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes: - logger.error("Something went wrong while downloading models") - sys.exit(0) - - -def maybe_download(model_storage_directory, url): - # using custom model - tar_file_name_list = [ - 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel' - ] - if not os.path.exists( - os.path.join(model_storage_directory, 'inference.pdiparams') - ) or not os.path.exists( - os.path.join(model_storage_directory, 'inference.pdmodel')): - tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) - print('download {} to {}'.format(url, tmp_path)) - os.makedirs(model_storage_directory, exist_ok=True) - download_with_progressbar(url, tmp_path) - with tarfile.open(tmp_path, 'r') as tarObj: - for member in tarObj.getmembers(): - filename = None - for tar_file_name in tar_file_name_list: - if tar_file_name in member.name: - filename = tar_file_name - if filename is None: - continue - file = tarObj.extractfile(member) - with open( - os.path.join(model_storage_directory, filename), - 'wb') as f: - f.write(file.read()) - os.remove(tmp_path) - - def parse_args(mMain=True): import argparse parser = init_args() @@ -194,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem): args: **kwargs: other params show in paddleocr --help """ - postprocess_params = parse_args(mMain=False) - postprocess_params.__dict__.update(**kwargs) - self.use_angle_cls = postprocess_params.use_angle_cls - lang = postprocess_params.lang + params = parse_args(mMain=False) + params.__dict__.update(**kwargs) + if params.show_log: + logger.setLevel(logging.DEBUG) + self.use_angle_cls = params.use_angle_cls + lang = params.lang latin_lang = [ 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', @@ -223,46 +180,46 @@ class PaddleOCR(predict_system.TextSystem): lang = "devanagari" assert lang in model_urls[ 'rec'], 'param lang must in {}, but got {}'.format( - model_urls['rec'].keys(), lang) + model_urls['rec'].keys(), lang) if lang == "ch": det_lang = "ch" else: det_lang = "en" use_inner_dict = False - if postprocess_params.rec_char_dict_path is None: + if params.rec_char_dict_path is None: use_inner_dict = True - postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ + params.rec_char_dict_path = model_urls['rec'][lang][ 'dict_path'] # init model dir - if postprocess_params.det_model_dir is None: - postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION, + if params.det_model_dir is None: + params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det', det_lang) - if postprocess_params.rec_model_dir is None: - postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION, + if params.rec_model_dir is None: + params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec', lang) - if postprocess_params.cls_model_dir is None: - postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls') - print(postprocess_params) + if params.cls_model_dir is None: + params.cls_model_dir = os.path.join(BASE_DIR, 'cls') # download model - maybe_download(postprocess_params.det_model_dir, + maybe_download(params.det_model_dir, model_urls['det'][det_lang]) - maybe_download(postprocess_params.rec_model_dir, + maybe_download(params.rec_model_dir, model_urls['rec'][lang]['url']) - maybe_download(postprocess_params.cls_model_dir, model_urls['cls']) + maybe_download(params.cls_model_dir, model_urls['cls']) - if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: + if params.det_algorithm not in SUPPORT_DET_MODEL: logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) sys.exit(0) - if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: + if params.rec_algorithm not in SUPPORT_REC_MODEL: logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) sys.exit(0) if use_inner_dict: - postprocess_params.rec_char_dict_path = str( - Path(__file__).parent / postprocess_params.rec_char_dict_path) + params.rec_char_dict_path = str( + Path(__file__).parent / params.rec_char_dict_path) + print(params) # init det_model and rec_model - super().__init__(postprocess_params) + super().__init__(params) def ocr(self, img, det=True, rec=True, cls=True): """ diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 9c48b09647527cf718113ea1b5df152ff7befa04..2535b4420c503f2e9e9cc5a677ef70c4dd9c36be 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -81,7 +81,7 @@ class NormalizeImage(object): assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" data['image'] = ( - img.astype('float32') * self.scale - self.mean) / self.std + img.astype('float32') * self.scale - self.mean) / self.std return data @@ -163,7 +163,7 @@ class DetResizeForTest(object): img, (ratio_h, ratio_w) """ limit_side_len = self.limit_side_len - h, w, _ = img.shape + h, w, c = img.shape # limit the max side if self.limit_type == 'max': @@ -174,7 +174,7 @@ class DetResizeForTest(object): ratio = float(limit_side_len) / w else: ratio = 1. - else: + elif self.limit_type == 'min': if min(h, w) < limit_side_len: if h < w: ratio = float(limit_side_len) / h @@ -182,6 +182,10 @@ class DetResizeForTest(object): ratio = float(limit_side_len) / w else: ratio = 1. + elif self.limit_type == 'resize_long': + ratio = float(limit_side_len) / max(h,w) + else: + raise Exception('not support limit type, image ') resize_h = int(h * ratio) resize_w = int(w * ratio) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index d353391c9af2b85bd01ba659f541fa1791461f68..85ce580f95b13539c6aeea32b188bfd3b435d140 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -44,16 +44,16 @@ class BaseRecLabelDecode(object): self.character_str = string.printable[:-6] dict_character = list(self.character_str) elif character_type in support_character_type: - self.character_str = "" + self.character_str = [] assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format( character_type) with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: line = line.decode('utf-8').strip("\n").strip("\r\n") - self.character_str += line + self.character_str.append(line) if use_space_char: - self.character_str += " " + self.character_str.append(" ") dict_character = list(self.character_str) else: @@ -288,3 +288,156 @@ class SRNLabelDecode(BaseRecLabelDecode): assert False, "unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class TableLabelDecode(object): + """ """ + + def __init__(self, + max_text_length, + max_elem_length, + max_cell_num, + character_dict_path, + **kwargs): + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num + list_character, list_elem = self.load_char_elem_dict(character_dict_path) + list_character = self.add_special_char(list_character) + list_elem = self.add_special_char(list_elem) + self.dict_character = {} + self.dict_idx_character = {} + for i, char in enumerate(list_character): + self.dict_idx_character[i] = char + self.dict_character[char] = i + self.dict_elem = {} + self.dict_idx_elem = {} + for i, elem in enumerate(list_elem): + self.dict_idx_elem[i] = elem + self.dict_elem[elem] = i + + def load_char_elem_dict(self, character_dict_path): + list_character = [] + list_elem = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + substr = lines[0].decode('utf-8').strip("\n").split("\t") + character_num = int(substr[0]) + elem_num = int(substr[1]) + for cno in range(1, 1 + character_num): + character = lines[cno].decode('utf-8').strip("\n") + list_character.append(character) + for eno in range(1 + character_num, 1 + character_num + elem_num): + elem = lines[eno].decode('utf-8').strip("\n") + list_elem.append(elem) + return list_character, list_elem + + def add_special_char(self, list_character): + self.beg_str = "sos" + self.end_str = "eos" + list_character = [self.beg_str] + list_character + [self.end_str] + return list_character + + def get_sp_tokens(self): + char_beg_idx = self.get_beg_end_flag_idx('beg', 'char') + char_end_idx = self.get_beg_end_flag_idx('end', 'char') + elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem') + elem_end_idx = self.get_beg_end_flag_idx('end', 'elem') + elem_char_idx1 = self.dict_elem[''] + elem_char_idx2 = self.dict_elem['', ' 0 and tmp_elem_idx == end_idx: + break + if tmp_elem_idx in ignored_tokens: + continue + + char_list.append(current_dict[tmp_elem_idx]) + elem_pos_list.append(idx) + score_list.append(structure_probs[batch_idx, idx]) + elem_idx_list.append(tmp_elem_idx) + result_list.append(char_list) + result_pos_list.append(elem_pos_list) + result_score_list.append(score_list) + result_elem_idx_list.append(elem_idx_list) + return result_list, result_pos_list, result_score_list, result_elem_idx_list + + def get_ignored_tokens(self, char_or_elem): + beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) + end_idx = self.get_beg_end_flag_idx("end", char_or_elem) + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): + if char_or_elem == "char": + if beg_or_end == "beg": + idx = self.dict_character[self.beg_str] + elif beg_or_end == "end": + idx = self.dict_character[self.end_str] + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ + % beg_or_end + elif char_or_elem == "elem": + if beg_or_end == "beg": + idx = self.dict_elem[self.beg_str] + elif beg_or_end == "end": + idx = self.dict_elem[self.end_str] + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ + % beg_or_end + else: + assert False, "Unsupport type %s in char_or_elem" \ + % char_or_elem + return idx diff --git a/ppocr/utils/dict/table_dict.txt b/ppocr/utils/dict/table_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..2ef028c786cbce6d1e25856c62986d757b31f93b --- /dev/null +++ b/ppocr/utils/dict/table_dict.txt @@ -0,0 +1,277 @@ +← + +☆ +─ +α + + +⋅ +$ +ω +ψ +χ +( +υ +≥ +σ +, +ρ +ε +0 +■ +4 +8 +✗ +b +< +✓ +Ψ +Ω +€ +D +3 +Π +H +║ + +L +Φ +Χ +θ +P +κ +λ +μ +T +ξ +X +β +γ +δ +\ +ζ +η +` +d + +h +f +l +Θ +p +√ +t + +x +Β +Γ +Δ +| +ǂ +ɛ +j +̧ +➢ +⁡ +̌ +′ +« +△ +▲ +# + +' +Ι ++ +¶ +/ +▼ +⇑ +□ +· +7 +▪ +; +? +➔ +∩ +C +÷ +G +⇒ +K + +O +S +С +W +Α +[ +○ +_ +● +‡ +c +z +g + +o + +〈 +〉 +s +⩽ +w +φ +ʹ +{ +» +∣ +̆ +e +ˆ +∈ +τ +◆ +ι +∅ +∆ +∙ +∘ +Ø +ß +✔ +∞ +∑ +− +× +◊ +∗ +∖ +˃ +˂ +∫ +" +i +& +π +↔ +* +∥ +æ +∧ +. +⁄ +ø +Q +∼ +6 +⁎ +: +★ +> +a +B +≈ +F +J +̄ +N +♯ +R +V + +― +Z +♣ +^ +¤ +¥ +§ + +¢ +£ +≦ +­ +≤ +‖ +Λ +© +n +↓ +→ +↑ +r +° +± +v + +♂ +k +♀ +~ +ᅟ +̇ +@ +” +♦ +ł +® +⊕ +„ +! + +% +⇓ +) +- +1 +5 +9 += +А +A +‰ +⋆ +Σ +E +◦ +I +※ +M +m +̨ +⩾ +† + +• +U +Y +
 +] +̸ +2 +‐ +– +‒ +̂ +— +̀ +́ +’ +‘ +⋮ +⋯ +̊ +“ +̈ +≧ +q +u +ı +y + +​ +̃ +} +ν diff --git a/ppocr/utils/dict/table_structure_dict.txt b/ppocr/utils/dict/table_structure_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c4531e5f3b8c498e70d3c2ea0471e5e746a2c30 --- /dev/null +++ b/ppocr/utils/dict/table_structure_dict.txt @@ -0,0 +1,2759 @@ +277 28 1267 1186 + +V +a +r +i +b +l +e + +H +z +d + +t +o +9 +5 +% +C +I + +p + +v +u +* +A +g +( +m +n +) +0 +. +7 +1 +6 +≤ +> +8 +3 +– +2 +G +4 +M +F +T +y +f +s +L +w +c +U +h +D +S +Q +R +x +P +- +E +O +/ +k +, ++ +N +K +q +′ +[ +] +< +≥ + +− + +μ +± +J +j +W +_ +Δ +B +“ +: +Y +α +λ +; + + +? +∼ += +° +# +̊ +̈ +̂ +’ +Z +X +∗ +— +β +' +† +~ +@ +" +γ +↓ +↑ +& +‡ +χ +” +σ +§ +| +¶ +‐ +× +$ +→ +√ +✓ +‘ +\ +∞ +π +• +® +^ +∆ +≧ + + +́ +♀ +♂ +‒ +⁎ +▲ +· +£ +φ +Ψ +ß +△ +☆ +▪ +η +€ +∧ +̃ +Φ +ρ +̄ +δ +‰ +̧ +Ω +♦ +{ +} +̀ +∑ +∫ +ø +κ +ε +¥ +※ +` +ω +Σ +➔ +‖ +Β +̸ +
 +─ +● +⩾ +Χ +Α +⋅ +◆ +★ +■ +ψ +ǂ +□ +ζ +! +Γ +↔ +θ +⁄ +〈 +〉 +― +υ +τ +⋆ +Ø +© +∥ +С +˂ +➢ +ɛ +⁡ +✗ +← +○ +¢ +⩽ +∖ +˃ +­ +≈ +Π +̌ +≦ +∅ +ᅟ + + +∣ +¤ +♯ +̆ +ξ +÷ +▼ + +ι +ν +║ + + +◦ +​ +◊ +∙ +« +» +ł +ı +Θ +∈ +„ +∘ +✔ +̇ +æ +ʹ +ˆ +♣ +⇓ +∩ +⊕ +⇒ +⇑ +̨ +Ι +Λ +⋯ +А +⋮ + + + + + + + + + + colspan="2" + colspan="3" + rowspan="2" + colspan="4" + colspan="6" + rowspan="3" + colspan="9" + colspan="10" + colspan="7" + rowspan="4" + rowspan="5" + rowspan="9" + colspan="8" + rowspan="8" + rowspan="6" + rowspan="7" + rowspan="10" +0 2924682 +1 3405345 +2 2363468 +3 2709165 +4 4078680 +5 3250792 +6 1923159 +7 1617890 +8 1450532 +9 1717624 +10 1477550 +11 1489223 +12 915528 +13 819193 +14 593660 +15 518924 +16 682065 +17 494584 +18 400591 +19 396421 +20 340994 +21 280688 +22 250328 +23 226786 +24 199927 +25 182707 +26 164629 +27 141613 +28 127554 +29 116286 +30 107682 +31 96367 +32 88002 +33 79234 +34 72186 +35 65921 +36 60374 +37 55976 +38 52166 +39 47414 +40 44932 +41 41279 +42 38232 +43 35463 +44 33703 +45 30557 +46 29639 +47 27000 +48 25447 +49 23186 +50 22093 +51 20412 +52 19844 +53 18261 +54 17561 +55 16499 +56 15597 +57 14558 +58 14372 +59 13445 +60 13514 +61 12058 +62 11145 +63 10767 +64 10370 +65 9630 +66 9337 +67 8881 +68 8727 +69 8060 +70 7994 +71 7740 +72 7189 +73 6729 +74 6749 +75 6548 +76 6321 +77 5957 +78 5740 +79 5407 +80 5370 +81 5035 +82 4921 +83 4656 +84 4600 +85 4519 +86 4277 +87 4023 +88 3939 +89 3910 +90 3861 +91 3560 +92 3483 +93 3406 +94 3346 +95 3229 +96 3122 +97 3086 +98 3001 +99 2884 +100 2822 +101 2677 +102 2670 +103 2610 +104 2452 +105 2446 +106 2400 +107 2300 +108 2316 +109 2196 +110 2089 +111 2083 +112 2041 +113 1881 +114 1838 +115 1896 +116 1795 +117 1786 +118 1743 +119 1765 +120 1750 +121 1683 +122 1563 +123 1499 +124 1513 +125 1462 +126 1388 +127 1441 +128 1417 +129 1392 +130 1306 +131 1321 +132 1274 +133 1294 +134 1240 +135 1126 +136 1157 +137 1130 +138 1084 +139 1130 +140 1083 +141 1040 +142 980 +143 1031 +144 974 +145 980 +146 932 +147 898 +148 960 +149 907 +150 852 +151 912 +152 859 +153 847 +154 876 +155 792 +156 791 +157 765 +158 788 +159 787 +160 744 +161 673 +162 683 +163 697 +164 666 +165 680 +166 632 +167 677 +168 657 +169 618 +170 587 +171 585 +172 567 +173 549 +174 562 +175 548 +176 542 +177 539 +178 542 +179 549 +180 547 +181 526 +182 525 +183 514 +184 512 +185 505 +186 515 +187 467 +188 475 +189 458 +190 435 +191 443 +192 427 +193 424 +194 404 +195 389 +196 429 +197 404 +198 386 +199 351 +200 388 +201 408 +202 361 +203 346 +204 324 +205 361 +206 363 +207 364 +208 323 +209 336 +210 342 +211 315 +212 325 +213 328 +214 314 +215 327 +216 320 +217 300 +218 295 +219 315 +220 310 +221 295 +222 275 +223 248 +224 274 +225 232 +226 293 +227 259 +228 286 +229 263 +230 242 +231 214 +232 261 +233 231 +234 211 +235 250 +236 233 +237 206 +238 224 +239 210 +240 233 +241 223 +242 216 +243 222 +244 207 +245 212 +246 196 +247 205 +248 201 +249 202 +250 211 +251 201 +252 215 +253 179 +254 163 +255 179 +256 191 +257 188 +258 196 +259 150 +260 154 +261 176 +262 211 +263 166 +264 171 +265 165 +266 149 +267 182 +268 159 +269 161 +270 164 +271 161 +272 141 +273 151 +274 127 +275 129 +276 142 +277 158 +278 148 +279 135 +280 127 +281 134 +282 138 +283 131 +284 126 +285 125 +286 130 +287 126 +288 135 +289 125 +290 135 +291 131 +292 95 +293 135 +294 106 +295 117 +296 136 +297 128 +298 128 +299 118 +300 109 +301 112 +302 117 +303 108 +304 120 +305 100 +306 95 +307 108 +308 112 +309 77 +310 120 +311 104 +312 109 +313 89 +314 98 +315 82 +316 98 +317 93 +318 77 +319 93 +320 77 +321 98 +322 93 +323 86 +324 89 +325 73 +326 70 +327 71 +328 77 +329 87 +330 77 +331 93 +332 100 +333 83 +334 72 +335 74 +336 69 +337 77 +338 68 +339 78 +340 90 +341 98 +342 75 +343 80 +344 63 +345 71 +346 83 +347 66 +348 71 +349 70 +350 62 +351 62 +352 59 +353 63 +354 62 +355 52 +356 64 +357 64 +358 56 +359 49 +360 57 +361 63 +362 60 +363 68 +364 62 +365 55 +366 54 +367 40 +368 75 +369 70 +370 53 +371 58 +372 57 +373 55 +374 69 +375 57 +376 53 +377 43 +378 45 +379 47 +380 56 +381 51 +382 59 +383 51 +384 43 +385 34 +386 57 +387 49 +388 39 +389 46 +390 48 +391 43 +392 40 +393 54 +394 50 +395 41 +396 43 +397 33 +398 27 +399 49 +400 44 +401 44 +402 38 +403 30 +404 32 +405 37 +406 39 +407 42 +408 53 +409 39 +410 34 +411 31 +412 32 +413 52 +414 27 +415 41 +416 34 +417 36 +418 50 +419 35 +420 32 +421 33 +422 45 +423 35 +424 40 +425 29 +426 41 +427 40 +428 39 +429 32 +430 31 +431 34 +432 29 +433 27 +434 26 +435 22 +436 34 +437 28 +438 30 +439 38 +440 35 +441 36 +442 36 +443 27 +444 24 +445 33 +446 31 +447 25 +448 33 +449 27 +450 32 +451 46 +452 31 +453 35 +454 35 +455 34 +456 26 +457 21 +458 25 +459 26 +460 24 +461 27 +462 33 +463 30 +464 35 +465 21 +466 32 +467 19 +468 27 +469 16 +470 28 +471 26 +472 27 +473 26 +474 25 +475 25 +476 27 +477 20 +478 28 +479 22 +480 23 +481 16 +482 25 +483 27 +484 19 +485 23 +486 19 +487 15 +488 15 +489 23 +490 24 +491 19 +492 20 +493 18 +494 17 +495 30 +496 28 +497 20 +498 29 +499 17 +500 19 +501 21 +502 15 +503 24 +504 15 +505 19 +506 25 +507 16 +508 23 +509 26 +510 21 +511 15 +512 12 +513 16 +514 18 +515 24 +516 26 +517 18 +518 8 +519 25 +520 14 +521 8 +522 24 +523 20 +524 18 +525 15 +526 13 +527 17 +528 18 +529 22 +530 21 +531 9 +532 16 +533 17 +534 13 +535 17 +536 15 +537 13 +538 20 +539 13 +540 19 +541 29 +542 10 +543 8 +544 18 +545 13 +546 9 +547 18 +548 10 +549 18 +550 18 +551 9 +552 9 +553 15 +554 13 +555 15 +556 14 +557 14 +558 18 +559 8 +560 13 +561 9 +562 7 +563 12 +564 6 +565 9 +566 9 +567 18 +568 9 +569 10 +570 13 +571 14 +572 13 +573 21 +574 8 +575 16 +576 12 +577 9 +578 16 +579 17 +580 22 +581 6 +582 14 +583 13 +584 15 +585 11 +586 13 +587 5 +588 12 +589 13 +590 15 +591 13 +592 15 +593 12 +594 7 +595 18 +596 12 +597 13 +598 13 +599 13 +600 12 +601 12 +602 10 +603 11 +604 6 +605 6 +606 2 +607 9 +608 8 +609 12 +610 9 +611 12 +612 13 +613 12 +614 14 +615 9 +616 8 +617 9 +618 14 +619 13 +620 12 +621 6 +622 8 +623 8 +624 8 +625 12 +626 8 +627 7 +628 5 +629 8 +630 12 +631 6 +632 10 +633 10 +634 7 +635 8 +636 9 +637 6 +638 9 +639 4 +640 12 +641 4 +642 3 +643 11 +644 10 +645 6 +646 12 +647 12 +648 4 +649 4 +650 9 +651 8 +652 6 +653 5 +654 14 +655 10 +656 11 +657 8 +658 5 +659 5 +660 9 +661 13 +662 4 +663 5 +664 9 +665 11 +666 12 +667 7 +668 13 +669 2 +670 1 +671 7 +672 7 +673 7 +674 10 +675 9 +676 6 +677 5 +678 7 +679 6 +680 3 +681 3 +682 4 +683 9 +684 8 +685 5 +686 3 +687 11 +688 9 +689 2 +690 6 +691 5 +692 9 +693 5 +694 6 +695 5 +696 9 +697 8 +698 3 +699 7 +700 5 +701 9 +702 8 +703 7 +704 2 +705 3 +706 7 +707 6 +708 6 +709 10 +710 2 +711 10 +712 6 +713 7 +714 5 +715 6 +716 4 +717 6 +718 8 +719 4 +720 6 +721 7 +722 5 +723 7 +724 3 +725 10 +726 10 +727 3 +728 7 +729 7 +730 5 +731 2 +732 1 +733 5 +734 1 +735 5 +736 6 +737 2 +738 2 +739 3 +740 7 +741 2 +742 7 +743 4 +744 5 +745 4 +746 5 +747 3 +748 1 +749 4 +750 4 +751 2 +752 4 +753 6 +754 6 +755 6 +756 3 +757 2 +758 5 +759 5 +760 3 +761 4 +762 2 +763 1 +764 8 +765 3 +766 4 +767 3 +768 1 +769 5 +770 3 +771 3 +772 4 +773 4 +774 1 +775 3 +776 2 +777 2 +778 3 +779 3 +780 1 +781 4 +782 3 +783 4 +784 6 +785 3 +786 5 +787 4 +788 2 +789 4 +790 5 +791 4 +792 6 +794 4 +795 1 +796 1 +797 4 +798 2 +799 3 +800 3 +801 1 +802 5 +803 5 +804 3 +805 3 +806 3 +807 4 +808 4 +809 2 +811 5 +812 4 +813 6 +814 3 +815 2 +816 2 +817 3 +818 5 +819 3 +820 1 +821 1 +822 4 +823 3 +824 4 +825 8 +826 3 +827 5 +828 5 +829 3 +830 6 +831 3 +832 4 +833 8 +834 5 +835 3 +836 3 +837 2 +838 4 +839 2 +840 1 +841 3 +842 2 +843 1 +844 3 +846 4 +847 4 +848 3 +849 3 +850 2 +851 3 +853 1 +854 4 +855 4 +856 2 +857 4 +858 1 +859 2 +860 5 +861 1 +862 1 +863 4 +864 2 +865 2 +867 5 +868 1 +869 4 +870 1 +871 1 +872 1 +873 2 +875 5 +876 3 +877 1 +878 3 +879 3 +880 3 +881 2 +882 1 +883 6 +884 2 +885 2 +886 1 +887 1 +888 3 +889 2 +890 2 +891 3 +892 1 +893 3 +894 1 +895 5 +896 1 +897 3 +899 2 +900 2 +902 1 +903 2 +904 4 +905 4 +906 3 +907 1 +908 1 +909 2 +910 5 +911 2 +912 3 +914 1 +915 1 +916 2 +918 2 +919 2 +920 4 +921 4 +922 1 +923 1 +924 4 +925 5 +926 1 +928 2 +929 1 +930 1 +931 1 +932 1 +933 1 +934 2 +935 1 +936 1 +937 1 +938 2 +939 1 +941 1 +942 4 +944 2 +945 2 +946 2 +947 1 +948 1 +950 1 +951 2 +953 1 +954 2 +955 1 +956 1 +957 2 +958 1 +960 3 +962 4 +963 1 +964 1 +965 3 +966 2 +967 2 +968 1 +969 3 +970 3 +972 1 +974 4 +975 3 +976 3 +977 2 +979 2 +980 1 +981 1 +983 5 +984 1 +985 3 +986 1 +987 2 +988 4 +989 2 +991 2 +992 2 +993 1 +994 1 +996 2 +997 2 +998 1 +999 3 +1000 2 +1001 1 +1002 3 +1003 3 +1004 2 +1005 3 +1006 1 +1007 2 +1009 1 +1011 1 +1013 3 +1014 1 +1016 2 +1017 1 +1018 1 +1019 1 +1020 4 +1021 1 +1022 2 +1025 1 +1026 1 +1027 2 +1028 1 +1030 1 +1031 2 +1032 4 +1034 3 +1035 2 +1036 1 +1038 1 +1039 1 +1040 1 +1041 1 +1042 2 +1043 1 +1044 2 +1045 4 +1048 1 +1050 1 +1051 1 +1052 2 +1054 1 +1055 3 +1056 2 +1057 1 +1059 1 +1061 2 +1063 1 +1064 1 +1065 1 +1066 1 +1067 1 +1068 1 +1069 2 +1074 1 +1075 1 +1077 1 +1078 1 +1079 1 +1082 1 +1085 1 +1088 1 +1090 1 +1091 1 +1092 2 +1094 2 +1097 2 +1098 1 +1099 2 +1101 2 +1102 1 +1104 1 +1105 1 +1107 1 +1109 1 +1111 2 +1112 1 +1114 2 +1115 2 +1116 2 +1117 1 +1118 1 +1119 1 +1120 1 +1122 1 +1123 1 +1127 1 +1128 3 +1132 2 +1138 3 +1142 1 +1145 4 +1150 1 +1153 2 +1154 1 +1158 1 +1159 1 +1163 1 +1165 1 +1169 2 +1174 1 +1176 1 +1177 1 +1178 2 +1179 1 +1180 2 +1181 1 +1182 1 +1183 2 +1185 1 +1187 1 +1191 2 +1193 1 +1195 3 +1196 1 +1201 3 +1203 1 +1206 1 +1210 1 +1213 1 +1214 1 +1215 2 +1218 1 +1220 1 +1221 1 +1225 1 +1226 1 +1233 2 +1241 1 +1243 1 +1249 1 +1250 2 +1251 1 +1254 1 +1255 2 +1260 1 +1268 1 +1270 1 +1273 1 +1274 1 +1277 1 +1284 1 +1287 1 +1291 1 +1292 2 +1294 1 +1295 2 +1297 1 +1298 1 +1301 1 +1307 1 +1308 3 +1311 2 +1313 1 +1316 1 +1321 1 +1324 1 +1325 1 +1330 1 +1333 1 +1334 1 +1338 2 +1340 1 +1341 1 +1342 1 +1343 1 +1345 1 +1355 1 +1357 1 +1360 2 +1375 1 +1376 1 +1380 1 +1383 1 +1387 1 +1389 1 +1393 1 +1394 1 +1396 1 +1398 1 +1410 1 +1414 1 +1419 1 +1425 1 +1434 1 +1435 1 +1438 1 +1439 1 +1447 1 +1455 2 +1460 1 +1461 1 +1463 1 +1466 1 +1470 1 +1473 1 +1478 1 +1480 1 +1483 1 +1484 1 +1485 2 +1492 2 +1499 1 +1509 1 +1512 1 +1513 1 +1523 1 +1524 1 +1525 2 +1529 1 +1539 1 +1544 1 +1568 1 +1584 1 +1591 1 +1598 1 +1600 1 +1604 1 +1614 1 +1617 1 +1621 1 +1622 1 +1626 1 +1638 1 +1648 1 +1658 1 +1661 1 +1679 1 +1682 1 +1693 1 +1700 1 +1705 1 +1707 1 +1722 1 +1728 1 +1758 1 +1762 1 +1763 1 +1775 1 +1776 1 +1801 1 +1810 1 +1812 1 +1827 1 +1834 1 +1846 1 +1847 1 +1848 1 +1851 1 +1862 1 +1866 1 +1877 2 +1884 1 +1888 1 +1903 1 +1912 1 +1925 1 +1938 1 +1955 1 +1998 1 +2054 1 +2058 1 +2065 1 +2069 1 +2076 1 +2089 1 +2104 1 +2111 1 +2133 1 +2138 1 +2156 1 +2204 1 +2212 1 +2237 1 +2246 2 +2298 1 +2304 1 +2360 1 +2400 1 +2481 1 +2544 1 +2586 1 +2622 1 +2666 1 +2682 1 +2725 1 +2920 1 +3997 1 +4019 1 +5211 1 +12 19 +14 1 +16 401 +18 2 +20 421 +22 557 +24 625 +26 50 +28 4481 +30 52 +32 550 +34 5840 +36 4644 +38 87 +40 5794 +41 33 +42 571 +44 11805 +46 4711 +47 7 +48 597 +49 12 +50 678 +51 2 +52 14715 +53 3 +54 7322 +55 3 +56 508 +57 39 +58 3486 +59 11 +60 8974 +61 45 +62 1276 +63 4 +64 15693 +65 15 +66 657 +67 13 +68 6409 +69 10 +70 3188 +71 25 +72 1889 +73 27 +74 10370 +75 9 +76 12432 +77 23 +78 520 +79 15 +80 1534 +81 29 +82 2944 +83 23 +84 12071 +85 36 +86 1502 +87 10 +88 10978 +89 11 +90 889 +91 16 +92 4571 +93 17 +94 7855 +95 21 +96 2271 +97 33 +98 1423 +99 15 +100 11096 +101 21 +102 4082 +103 13 +104 5442 +105 25 +106 2113 +107 26 +108 3779 +109 43 +110 1294 +111 29 +112 7860 +113 29 +114 4965 +115 22 +116 7898 +117 25 +118 1772 +119 28 +120 1149 +121 38 +122 1483 +123 32 +124 10572 +125 25 +126 1147 +127 31 +128 1699 +129 22 +130 5533 +131 22 +132 4669 +133 34 +134 3777 +135 10 +136 5412 +137 21 +138 855 +139 26 +140 2485 +141 46 +142 1970 +143 27 +144 6565 +145 40 +146 933 +147 15 +148 7923 +149 16 +150 735 +151 23 +152 1111 +153 33 +154 3714 +155 27 +156 2445 +157 30 +158 3367 +159 10 +160 4646 +161 27 +162 990 +163 23 +164 5679 +165 25 +166 2186 +167 17 +168 899 +169 32 +170 1034 +171 22 +172 6185 +173 32 +174 2685 +175 17 +176 1354 +177 38 +178 1460 +179 15 +180 3478 +181 20 +182 958 +183 20 +184 6055 +185 23 +186 2180 +187 15 +188 1416 +189 30 +190 1284 +191 22 +192 1341 +193 21 +194 2413 +195 18 +196 4984 +197 13 +198 830 +199 22 +200 1834 +201 19 +202 2238 +203 9 +204 3050 +205 22 +206 616 +207 17 +208 2892 +209 22 +210 711 +211 30 +212 2631 +213 19 +214 3341 +215 21 +216 987 +217 26 +218 823 +219 9 +220 3588 +221 20 +222 692 +223 7 +224 2925 +225 31 +226 1075 +227 16 +228 2909 +229 18 +230 673 +231 20 +232 2215 +233 14 +234 1584 +235 21 +236 1292 +237 29 +238 1647 +239 25 +240 1014 +241 30 +242 1648 +243 19 +244 4465 +245 10 +246 787 +247 11 +248 480 +249 25 +250 842 +251 15 +252 1219 +253 23 +254 1508 +255 8 +256 3525 +257 16 +258 490 +259 12 +260 1678 +261 14 +262 822 +263 16 +264 1729 +265 28 +266 604 +267 11 +268 2572 +269 7 +270 1242 +271 15 +272 725 +273 18 +274 1983 +275 13 +276 1662 +277 19 +278 491 +279 12 +280 1586 +281 14 +282 563 +283 10 +284 2363 +285 10 +286 656 +287 14 +288 725 +289 28 +290 871 +291 9 +292 2606 +293 12 +294 961 +295 9 +296 478 +297 13 +298 1252 +299 10 +300 736 +301 19 +302 466 +303 13 +304 2254 +305 12 +306 486 +307 14 +308 1145 +309 13 +310 955 +311 13 +312 1235 +313 13 +314 931 +315 14 +316 1768 +317 11 +318 330 +319 10 +320 539 +321 23 +322 570 +323 12 +324 1789 +325 13 +326 884 +327 5 +328 1422 +329 14 +330 317 +331 11 +332 509 +333 13 +334 1062 +335 12 +336 577 +337 27 +338 378 +339 10 +340 2313 +341 9 +342 391 +343 13 +344 894 +345 17 +346 664 +347 9 +348 453 +349 6 +350 363 +351 15 +352 1115 +353 13 +354 1054 +355 8 +356 1108 +357 12 +358 354 +359 7 +360 363 +361 16 +362 344 +363 11 +364 1734 +365 12 +366 265 +367 10 +368 969 +369 16 +370 316 +371 12 +372 757 +373 7 +374 563 +375 15 +376 857 +377 9 +378 469 +379 9 +380 385 +381 12 +382 921 +383 15 +384 764 +385 14 +386 246 +387 6 +388 1108 +389 14 +390 230 +391 8 +392 266 +393 11 +394 641 +395 8 +396 719 +397 9 +398 243 +399 4 +400 1108 +401 7 +402 229 +403 7 +404 903 +405 7 +406 257 +407 12 +408 244 +409 3 +410 541 +411 6 +412 744 +413 8 +414 419 +415 8 +416 388 +417 19 +418 470 +419 14 +420 612 +421 6 +422 342 +423 3 +424 1179 +425 3 +426 116 +427 14 +428 207 +429 6 +430 255 +431 4 +432 288 +433 12 +434 343 +435 6 +436 1015 +437 3 +438 538 +439 10 +440 194 +441 6 +442 188 +443 15 +444 524 +445 7 +446 214 +447 7 +448 574 +449 6 +450 214 +451 5 +452 635 +453 9 +454 464 +455 5 +456 205 +457 9 +458 163 +459 2 +460 558 +461 4 +462 171 +463 14 +464 444 +465 11 +466 543 +467 5 +468 388 +469 6 +470 141 +471 4 +472 647 +473 3 +474 210 +475 4 +476 193 +477 7 +478 195 +479 7 +480 443 +481 10 +482 198 +483 3 +484 816 +485 6 +486 128 +487 9 +488 215 +489 9 +490 328 +491 7 +492 158 +493 11 +494 335 +495 8 +496 435 +497 6 +498 174 +499 1 +500 373 +501 5 +502 140 +503 7 +504 330 +505 9 +506 149 +507 5 +508 642 +509 3 +510 179 +511 3 +512 159 +513 8 +514 204 +515 7 +516 306 +517 4 +518 110 +519 5 +520 326 +521 6 +522 305 +523 6 +524 294 +525 7 +526 268 +527 5 +528 149 +529 4 +530 133 +531 2 +532 513 +533 10 +534 116 +535 5 +536 258 +537 4 +538 113 +539 4 +540 138 +541 6 +542 116 +544 485 +545 4 +546 93 +547 9 +548 299 +549 3 +550 256 +551 6 +552 92 +553 3 +554 175 +555 6 +556 253 +557 7 +558 95 +559 2 +560 128 +561 4 +562 206 +563 2 +564 465 +565 3 +566 69 +567 3 +568 157 +569 7 +570 97 +571 8 +572 118 +573 5 +574 130 +575 4 +576 301 +577 6 +578 177 +579 2 +580 397 +581 3 +582 80 +583 1 +584 128 +585 5 +586 52 +587 2 +588 72 +589 1 +590 84 +591 6 +592 323 +593 11 +594 77 +595 5 +596 205 +597 1 +598 244 +599 4 +600 69 +601 3 +602 89 +603 5 +604 254 +605 6 +606 147 +607 3 +608 83 +609 3 +610 77 +611 3 +612 194 +613 1 +614 98 +615 3 +616 243 +617 3 +618 50 +619 8 +620 188 +621 4 +622 67 +623 4 +624 123 +625 2 +626 50 +627 1 +628 239 +629 2 +630 51 +631 4 +632 65 +633 5 +634 188 +636 81 +637 3 +638 46 +639 3 +640 103 +641 1 +642 136 +643 3 +644 188 +645 3 +646 58 +648 122 +649 4 +650 47 +651 2 +652 155 +653 4 +654 71 +655 1 +656 71 +657 3 +658 50 +659 2 +660 177 +661 5 +662 66 +663 2 +664 183 +665 3 +666 50 +667 2 +668 53 +669 2 +670 115 +672 66 +673 2 +674 47 +675 1 +676 197 +677 2 +678 46 +679 3 +680 95 +681 3 +682 46 +683 3 +684 107 +685 1 +686 86 +687 2 +688 158 +689 4 +690 51 +691 1 +692 80 +694 56 +695 4 +696 40 +698 43 +699 3 +700 95 +701 2 +702 51 +703 2 +704 133 +705 1 +706 100 +707 2 +708 121 +709 2 +710 15 +711 3 +712 35 +713 2 +714 20 +715 3 +716 37 +717 2 +718 78 +720 55 +721 1 +722 42 +723 2 +724 218 +725 3 +726 23 +727 2 +728 26 +729 1 +730 64 +731 2 +732 65 +734 24 +735 2 +736 53 +737 1 +738 32 +739 1 +740 60 +742 81 +743 1 +744 77 +745 1 +746 47 +747 1 +748 62 +749 1 +750 19 +751 1 +752 86 +753 3 +754 40 +756 55 +757 2 +758 38 +759 1 +760 101 +761 1 +762 22 +764 67 +765 2 +766 35 +767 1 +768 38 +769 1 +770 22 +771 1 +772 82 +773 1 +774 73 +776 29 +777 1 +778 55 +780 23 +781 1 +782 16 +784 84 +785 3 +786 28 +788 59 +789 1 +790 33 +791 3 +792 24 +794 13 +795 1 +796 110 +797 2 +798 15 +800 22 +801 3 +802 29 +803 1 +804 87 +806 21 +808 29 +810 48 +812 28 +813 1 +814 58 +815 1 +816 48 +817 1 +818 31 +819 1 +820 66 +822 17 +823 2 +824 58 +826 10 +827 2 +828 25 +829 1 +830 29 +831 1 +832 63 +833 1 +834 26 +835 3 +836 52 +837 1 +838 18 +840 27 +841 2 +842 12 +843 1 +844 83 +845 1 +846 7 +847 1 +848 10 +850 26 +852 25 +853 1 +854 15 +856 27 +858 32 +859 1 +860 15 +862 43 +864 32 +865 1 +866 6 +868 39 +870 11 +872 25 +873 1 +874 10 +875 1 +876 20 +877 2 +878 19 +879 1 +880 30 +882 11 +884 53 +886 25 +887 1 +888 28 +890 6 +892 36 +894 10 +896 13 +898 14 +900 31 +902 14 +903 2 +904 43 +906 25 +908 9 +910 11 +911 1 +912 16 +913 1 +914 24 +916 27 +918 6 +920 15 +922 27 +923 1 +924 23 +926 13 +928 42 +929 1 +930 3 +932 27 +934 17 +936 8 +937 1 +938 11 +940 33 +942 4 +943 1 +944 18 +946 15 +948 13 +950 18 +952 12 +954 11 +956 21 +958 10 +960 13 +962 5 +964 32 +966 13 +968 8 +970 8 +971 1 +972 23 +973 2 +974 12 +975 1 +976 22 +978 7 +979 1 +980 14 +982 8 +984 22 +985 1 +986 6 +988 17 +989 1 +990 6 +992 13 +994 19 +996 11 +998 4 +1000 9 +1002 2 +1004 14 +1006 5 +1008 3 +1010 9 +1012 29 +1014 6 +1016 22 +1017 1 +1018 8 +1019 1 +1020 7 +1022 6 +1023 1 +1024 10 +1026 2 +1028 8 +1030 11 +1031 2 +1032 8 +1034 9 +1036 13 +1038 12 +1040 12 +1042 3 +1044 12 +1046 3 +1048 11 +1050 2 +1051 1 +1052 2 +1054 11 +1056 6 +1058 8 +1059 1 +1060 23 +1062 6 +1063 1 +1064 8 +1066 3 +1068 6 +1070 8 +1071 1 +1072 5 +1074 3 +1076 5 +1078 3 +1080 11 +1081 1 +1082 7 +1084 18 +1086 4 +1087 1 +1088 3 +1090 3 +1092 7 +1094 3 +1096 12 +1098 6 +1099 1 +1100 2 +1102 6 +1104 14 +1106 3 +1108 6 +1110 5 +1112 2 +1114 8 +1116 3 +1118 3 +1120 7 +1122 10 +1124 6 +1126 8 +1128 1 +1130 4 +1132 3 +1134 2 +1136 5 +1138 5 +1140 8 +1142 3 +1144 7 +1146 3 +1148 11 +1150 1 +1152 5 +1154 1 +1156 5 +1158 1 +1160 5 +1162 3 +1164 6 +1165 1 +1166 1 +1168 4 +1169 1 +1170 3 +1171 1 +1172 2 +1174 5 +1176 3 +1177 1 +1180 8 +1182 2 +1184 4 +1186 2 +1188 3 +1190 2 +1192 5 +1194 6 +1196 1 +1198 2 +1200 2 +1204 10 +1206 2 +1208 9 +1210 1 +1214 6 +1216 3 +1218 4 +1220 9 +1221 2 +1222 1 +1224 5 +1226 4 +1228 8 +1230 1 +1232 1 +1234 3 +1236 5 +1240 3 +1242 1 +1244 3 +1245 1 +1246 4 +1248 6 +1250 2 +1252 7 +1256 3 +1258 2 +1260 2 +1262 3 +1264 4 +1265 1 +1266 1 +1270 1 +1271 1 +1272 2 +1274 3 +1276 3 +1278 1 +1280 3 +1284 1 +1286 1 +1290 1 +1292 3 +1294 1 +1296 7 +1300 2 +1302 4 +1304 3 +1306 2 +1308 2 +1312 1 +1314 1 +1316 3 +1318 2 +1320 1 +1324 8 +1326 1 +1330 1 +1331 1 +1336 2 +1338 1 +1340 3 +1341 1 +1344 1 +1346 2 +1347 1 +1348 3 +1352 1 +1354 2 +1356 1 +1358 1 +1360 3 +1362 1 +1364 4 +1366 1 +1370 1 +1372 3 +1380 2 +1384 2 +1388 2 +1390 2 +1392 2 +1394 1 +1396 1 +1398 1 +1400 2 +1402 1 +1404 1 +1406 1 +1410 1 +1412 5 +1418 1 +1420 1 +1424 1 +1432 2 +1434 2 +1442 3 +1444 5 +1448 1 +1454 1 +1456 1 +1460 3 +1462 4 +1468 1 +1474 1 +1476 1 +1478 2 +1480 1 +1486 2 +1488 1 +1492 1 +1496 1 +1500 3 +1503 1 +1506 1 +1512 2 +1516 1 +1522 1 +1524 2 +1534 4 +1536 1 +1538 1 +1540 2 +1544 2 +1548 1 +1556 1 +1560 1 +1562 1 +1564 2 +1566 1 +1568 1 +1570 1 +1572 1 +1576 1 +1590 1 +1594 1 +1604 1 +1608 1 +1614 1 +1622 1 +1624 2 +1628 1 +1629 1 +1636 1 +1642 1 +1654 2 +1660 1 +1664 1 +1670 1 +1684 4 +1698 1 +1732 3 +1742 1 +1752 1 +1760 1 +1764 1 +1772 2 +1798 1 +1808 1 +1820 1 +1852 1 +1856 1 +1874 1 +1902 1 +1908 1 +1952 1 +2004 1 +2018 1 +2020 1 +2028 1 +2174 1 +2233 1 +2244 1 +2280 1 +2290 1 +2352 1 +2604 1 +4190 1 diff --git a/ppocr/utils/network.py b/ppocr/utils/network.py new file mode 100644 index 0000000000000000000000000000000000000000..453abb693d4c0ed370c1031b677d5bf51661add9 --- /dev/null +++ b/ppocr/utils/network.py @@ -0,0 +1,82 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import tarfile +import requests +from tqdm import tqdm + +from ppocr.utils.logging import get_logger + + +def download_with_progressbar(url, save_path): + logger = get_logger() + response = requests.get(url, stream=True) + total_size_in_bytes = int(response.headers.get('content-length', 0)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) + with open(save_path, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes: + logger.error("Something went wrong while downloading models") + sys.exit(0) + + +def maybe_download(model_storage_directory, url): + # using custom model + tar_file_name_list = [ + 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel' + ] + if not os.path.exists( + os.path.join(model_storage_directory, 'inference.pdiparams') + ) or not os.path.exists( + os.path.join(model_storage_directory, 'inference.pdmodel')): + assert url.endswith('.tar'), 'Only supports tar compressed package' + tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) + print('download {} to {}'.format(url, tmp_path)) + os.makedirs(model_storage_directory, exist_ok=True) + download_with_progressbar(url, tmp_path) + with tarfile.open(tmp_path, 'r') as tarObj: + for member in tarObj.getmembers(): + filename = None + for tar_file_name in tar_file_name_list: + if tar_file_name in member.name: + filename = tar_file_name + if filename is None: + continue + file = tarObj.extractfile(member) + with open( + os.path.join(model_storage_directory, filename), + 'wb') as f: + f.write(file.read()) + os.remove(tmp_path) + + +def is_link(s): + return s is not None and s.startswith('http') + + +def confirm_model_dir_url(model_dir, default_model_dir, default_url): + url = default_url + if model_dir is None or is_link(model_dir): + if is_link(model_dir): + url = model_dir + file_name = url.split('/')[-1][:-4] + model_dir = default_model_dir + model_dir = os.path.join(model_dir, file_name) + return model_dir, url diff --git a/ppstructure/MANIFEST.in b/ppstructure/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..2961e722b7cebe8e1912be2dd903fcdecb694019 --- /dev/null +++ b/ppstructure/MANIFEST.in @@ -0,0 +1,9 @@ +include LICENSE +include README.md + +recursive-include ppocr/utils *.txt utility.py logging.py network.py +recursive-include ppocr/data/ *.py +recursive-include ppocr/postprocess *.py +recursive-include tools/infer *.py +recursive-include ppstructure *.py + diff --git a/ppstructure/README_ch.md b/ppstructure/README_ch.md index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..22505ad83c6dc58adf472f3db94cbf608b9bbd01 100644 --- a/ppstructure/README_ch.md +++ b/ppstructure/README_ch.md @@ -0,0 +1,30 @@ +# TableStructurer + +1. 代码使用 +```python +import cv2 +from paddlestructure import PaddleStructure,draw_result + +table_engine = PaddleStructure( + output='./output/table', + show_log=True) + +img_path = '../doc/table/1.png' +img = cv2.imread(img_path) +result = table_engine(img) +for line in result: + print(line) + +from PIL import Image + +font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf' +image = Image.open(img_path).convert('RGB') +im_show = draw_result(image, result,font_path=font_path) +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` + +2. 命令行使用 +```bash +paddlestructure --image_dir=../doc/table/1.png +``` diff --git a/ppstructure/__init__.py b/ppstructure/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7055bee443fb86648b80bcb892778a114bc47d71 --- /dev/null +++ b/ppstructure/__init__.py @@ -0,0 +1,17 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .paddlestructure import PaddleStructure, draw_result, to_excel + +__all__ = ['PaddleStructure', 'draw_result', 'to_excel'] diff --git a/ppstructure/layout/README.md b/ppstructure/layout/README.md deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/ppstructure/layout/README_ch.md b/ppstructure/layout/README_ch.md deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/ppstructure/paddlestructure.py b/ppstructure/paddlestructure.py new file mode 100644 index 0000000000000000000000000000000000000000..57a53d6496f66771f1f6f7628751b4f0ac0fc3b5 --- /dev/null +++ b/ppstructure/paddlestructure.py @@ -0,0 +1,148 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import sys + +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) + +import cv2 +import numpy as np +from pathlib import Path + +from ppocr.utils.logging import get_logger +from ppstructure.predict_system import OCRSystem, save_res +from ppstructure.table.predict_table import to_excel +from ppstructure.utility import init_args, draw_result + +logger = get_logger() +from ppocr.utils.utility import check_and_read_gif, get_image_file_list +from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link + +__all__ = ['PaddleStructure', 'draw_result', 'to_excel'] + +VERSION = '2.1' +BASE_DIR = os.path.expanduser("~/.paddlestructure/") + +model_urls = { + 'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar', + 'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar', + 'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar' + +} + + +def parse_args(mMain=True): + import argparse + parser = init_args() + parser.add_help = mMain + + for action in parser._actions: + if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']: + action.default = None + if mMain: + return parser.parse_args() + else: + inference_args_dict = {} + for action in parser._actions: + inference_args_dict[action.dest] = action.default + return argparse.Namespace(**inference_args_dict) + + +class PaddleStructure(OCRSystem): + def __init__(self, **kwargs): + params = parse_args(mMain=False) + params.__dict__.update(**kwargs) + if params.show_log: + logger.setLevel(logging.DEBUG) + params.use_angle_cls = False + # init model dir + params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir, + os.path.join(BASE_DIR, VERSION, 'det'), + model_urls['det']) + params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir, + os.path.join(BASE_DIR, VERSION, 'rec'), + model_urls['rec']) + params.structure_model_dir, structure_url = confirm_model_dir_url(params.structure_model_dir, + os.path.join(BASE_DIR, VERSION, 'structure'), + model_urls['structure']) + # download model + maybe_download(params.det_model_dir, det_url) + maybe_download(params.rec_model_dir, rec_url) + maybe_download(params.structure_model_dir, structure_url) + + if params.rec_char_dict_path is None: + params.rec_char_type = 'EN' + if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')): + params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt') + else: + params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt') + if params.structure_char_dict_path is None: + if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')): + params.structure_char_dict_path = str( + Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt') + else: + params.structure_char_dict_path = str( + Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt') + + print(params) + super().__init__(params) + + def __call__(self, img): + if isinstance(img, str): + # download net image + if img.startswith('http'): + download_with_progressbar(img, 'tmp.jpg') + img = 'tmp.jpg' + image_file = img + img, flag = check_and_read_gif(image_file) + if not flag: + with open(image_file, 'rb') as f: + np_arr = np.frombuffer(f.read(), dtype=np.uint8) + img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if img is None: + logger.error("error in loading image:{}".format(image_file)) + return None + if isinstance(img, np.ndarray) and len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + res = super().__call__(img) + return res + + +def main(): + # for cmd + args = parse_args(mMain=True) + image_dir = args.image_dir + save_folder = args.output + if image_dir.startswith('http'): + download_with_progressbar(image_dir, 'tmp.jpg') + image_file_list = ['tmp.jpg'] + else: + image_file_list = get_image_file_list(args.image_dir) + if len(image_file_list) == 0: + logger.error('no images find in {}'.format(args.image_dir)) + return + + structure_engine = PaddleStructure(**(args.__dict__)) + for img_path in image_file_list: + img_name = os.path.basename(img_path).split('.')[0] + logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10)) + result = structure_engine(img_path) + for item in result: + logger.info(item['res']) + save_res(result, save_folder, img_name) + logger.info('result save to {}'.format(os.path.join(save_folder, img_name))) \ No newline at end of file diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..2cdfcce2eb3ad4abe4407f781eb99e3591ecebde 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import subprocess + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' +import cv2 +import numpy as np +import time + +import layoutparser as lp + +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.utils.logging import get_logger +from tools.infer.predict_system import TextSystem +from ppstructure.table.predict_table import TableSystem, to_excel +from ppstructure.utility import parse_args,draw_result + +logger = get_logger() + + +class OCRSystem(object): + def __init__(self, args): + args.det_limit_type = 'resize_long' + args.drop_score = 0 + self.text_system = TextSystem(args) + self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer) + self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config", + threshold=0.5, enable_mkldnn=args.enable_mkldnn, + enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads) + self.use_angle_cls = args.use_angle_cls + self.drop_score = args.drop_score + + def __call__(self, img): + ori_im = img.copy() + layout_res = self.table_layout.detect(img[..., ::-1]) + res_list = [] + for region in layout_res: + x1, y1, x2, y2 = region.coordinates + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + roi_img = ori_im[y1:y2, x1:x2, :] + if region.type == 'Table': + res = self.table_system(roi_img) + elif region.type == 'Figure': + continue + else: + filter_boxes, filter_rec_res = self.text_system(roi_img) + filter_boxes = [x + [x1, y1] for x in filter_boxes] + filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes] + + res = (filter_boxes, filter_rec_res) + res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res}) + return res_list + +def save_res(res, save_folder, img_name): + excel_save_folder = os.path.join(save_folder, img_name) + os.makedirs(excel_save_folder, exist_ok=True) + # save res + for region in res: + if region['type'] == 'Table': + excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox'])) + to_excel(region['res'], excel_path) + elif region['type'] == 'Figure': + pass + else: + with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f: + for box, rec_res in zip(region['res'][0], region['res'][1]): + f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res)) + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + image_file_list = image_file_list + image_file_list = image_file_list[args.process_id::args.total_process_num] + save_folder = args.output + os.makedirs(save_folder, exist_ok=True) + + structure_sys = OCRSystem(args) + img_num = len(image_file_list) + for i, image_file in enumerate(image_file_list): + logger.info("[{}/{}] {}".format(i, img_num, image_file)) + img, flag = check_and_read_gif(image_file) + img_name = os.path.basename(image_file).split('.')[0] + + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.error("error in loading image:{}".format(image_file)) + continue + starttime = time.time() + res = structure_sys(img) + save_res(res, save_folder, img_name) + draw_img = draw_result(img,res, args.vis_font_path) + cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img) + logger.info('result save to {}'.format(os.path.join(save_folder, img_name))) + elapse = time.time() - starttime + logger.info("Predict time : {:.3f}s".format(elapse)) + + +if __name__ == "__main__": + args = parse_args() + if args.use_mp: + p_list = [] + total_process_num = args.total_process_num + for process_id in range(total_process_num): + cmd = [sys.executable, "-u"] + sys.argv + [ + "--process_id={}".format(process_id), + "--use_mp={}".format(False) + ] + p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout) + p_list.append(p) + for p in p_list: + p.wait() + else: + main(args) diff --git a/ppstructure/setup.py b/ppstructure/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..8e68b2e44140f6ad5a13661349666d17cfe45524 --- /dev/null +++ b/ppstructure/setup.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from setuptools import setup +from io import open +import shutil + +with open('../requirements.txt', encoding="utf-8-sig") as f: + requirements = f.readlines() + requirements.append('tqdm') + requirements.append('layoutparser') + requirements.append('iopath') + + +def readme(): + with open('README_ch.md', encoding="utf-8-sig") as f: + README = f.read() + return README + + +shutil.copytree('../ppstructure/table', './ppstructure/table') +shutil.copyfile('../ppstructure/predict_system.py', './ppstructure/predict_system.py') +shutil.copyfile('../ppstructure/utility.py', './ppstructure/utility.py') +shutil.copytree('../ppocr', './ppocr') +shutil.copytree('../tools', './tools') +shutil.copyfile('../LICENSE', './LICENSE') + +setup( + name='paddlestructure', + packages=['paddlestructure'], + package_dir={'paddlestructure': ''}, + include_package_data=True, + entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]}, + version='1.0', + install_requires=requirements, + license='Apache License 2.0', + description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', + long_description=readme(), + long_description_content_type='text/markdown', + url='https://github.com/PaddlePaddle/PaddleOCR', + download_url='https://github.com/PaddlePaddle/PaddleOCR.git', + keywords=[ + 'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition' + ], + classifiers=[ + 'Intended Audience :: Developers', 'Operating System :: OS Independent', + 'Natural Language :: Chinese (Simplified)', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.2', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Utilities' + ], ) + +shutil.rmtree('ppocr') +shutil.rmtree('tools') +shutil.rmtree('ppstructure') +os.remove('LICENSE') diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..105231068a99eb6c012a125ba3fb65934c5d4ac6 100644 --- a/ppstructure/table/README_ch.md +++ b/ppstructure/table/README_ch.md @@ -0,0 +1,15 @@ +# 表格结构和内容预测 + +先cd到PaddleOCR/ppstructure目录下 + +预测 +```python +python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --table_output ../output/table +``` +运行完成后,每张图片的excel表格会保存到table_output字段指定的目录下 + +评估 + +```python +python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json +``` diff --git a/ppstructure/table/__init__.py b/ppstructure/table/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d11e265597c7c8e39098a228108da3bb954b892 --- /dev/null +++ b/ppstructure/table/__init__.py @@ -0,0 +1,13 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py new file mode 100755 index 0000000000000000000000000000000000000000..1bcbaa8d0d0b2669828dc6b19c3370a30c522ede --- /dev/null +++ b/ppstructure/table/eval_table.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +import cv2 +import json +from tqdm import tqdm +from ppstructure.table.table_metric import TEDS +from ppstructure.table.predict_table import TableSystem +from ppstructure.utility import init_args + + +def parse_args(): + parser = init_args() + parser.add_argument("--gt_path", type=str) + return parser.parse_args() + +def main(gt_path, img_root, args): + teds = TEDS(n_jobs=16) + + text_sys = TableSystem(args) + jsons_gt = json.load(open(gt_path)) # gt + pred_htmls = [] + gt_htmls = [] + for img_name in tqdm(jsons_gt): + # read image + img = cv2.imread(os.path.join(img_root,img_name)) + pred_html = text_sys(img) + pred_htmls.append(pred_html) + + gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name] + gt_html, gt = get_gt_html(gt_structures, contents_with_block) + gt_htmls.append(gt_html) + scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) + print('teds:', sum(scores) / len(scores)) + + +def get_gt_html(gt_structures, contents_with_block): + end_html = [] + td_index = 0 + for tag in gt_structures: + if '' in tag: + if contents_with_block[td_index] != []: + end_html.extend(contents_with_block[td_index]) + end_html.append(tag) + td_index += 1 + else: + end_html.append(tag) + return ''.join(end_html), end_html + + +if __name__ == '__main__': + args = parse_args() + main(args.gt_path,args.image_dir, args) diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py new file mode 100755 index 0000000000000000000000000000000000000000..c3b56384403f5fd92a8db4b4bb378a6d55e5a76c --- /dev/null +++ b/ppstructure/table/matcher.py @@ -0,0 +1,192 @@ +import json +def distance(box_1, box_2): + x1, y1, x2, y2 = box_1 + x3, y3, x4, y4 = box_2 + dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2) + dis_2 = abs(x3 - x1) + abs(y3 - y1) + dis_3 = abs(x4- x2) + abs(y4 - y2) + return dis + min(dis_2, dis_3) + +def compute_iou(rec1, rec2): + """ + computing IoU + :param rec1: (y0, x0, y1, x1), which reflects + (top, left, bottom, right) + :param rec2: (y0, x0, y1, x1) + :return: scala value of IoU + """ + # computing area of each rectangles + S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) + S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) + + # computing the sum_area + sum_area = S_rec1 + S_rec2 + + # find the each edge of intersect rectangle + left_line = max(rec1[1], rec2[1]) + right_line = min(rec1[3], rec2[3]) + top_line = max(rec1[0], rec2[0]) + bottom_line = min(rec1[2], rec2[2]) + + # judge if there is an intersect + if left_line >= right_line or top_line >= bottom_line: + return 0.0 + else: + intersect = (right_line - left_line) * (bottom_line - top_line) + return (intersect / (sum_area - intersect))*1.0 + + + +def matcher_merge(ocr_bboxes, pred_bboxes): + all_dis = [] + ious = [] + matched = {} + for i, gt_box in enumerate(ocr_bboxes): + distances = [] + for j, pred_box in enumerate(pred_bboxes): + # compute l1 distence and IOU between two boxes + distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) + sorted_distances = distances.copy() + # select nearest cell + sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0])) + if distances.index(sorted_distances[0]) not in matched.keys(): + matched[distances.index(sorted_distances[0])] = [i] + else: + matched[distances.index(sorted_distances[0])].append(i) + return matched#, sum(ious) / len(ious) + +def complex_num(pred_bboxes): + complex_nums = [] + for bbox in pred_bboxes: + distances = [] + temp_ious = [] + for pred_bbox in pred_bboxes: + if bbox != pred_bbox: + distances.append(distance(bbox, pred_bbox)) + temp_ious.append(compute_iou(bbox, pred_bbox)) + complex_nums.append(temp_ious[distances.index(min(distances))]) + return sum(complex_nums) / len(complex_nums) + +def get_rows(pred_bboxes): + pre_bbox = pred_bboxes[0] + res = [] + step = 0 + for i in range(len(pred_bboxes)): + bbox = pred_bboxes[i] + if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0: + break + else: + res.append(bbox) + step += 1 + for i in range(step): + pred_bboxes.pop(0) + return res, pred_bboxes +def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上 + ys_1 = [] + ys_2 = [] + for box in pred_bboxes: + ys_1.append(box[1]) + ys_2.append(box[3]) + min_y_1 = sum(ys_1) / len(ys_1) + min_y_2 = sum(ys_2) / len(ys_2) + re_boxes = [] + for box in pred_bboxes: + box[1] = min_y_1 + box[3] = min_y_2 + re_boxes.append(box) + return re_boxes + +def matcher_refine_row(gt_bboxes, pred_bboxes): + before_refine_pred_bboxes = pred_bboxes.copy() + pred_bboxes = [] + while(len(before_refine_pred_bboxes) != 0): + row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes) + print(row_bboxes) + pred_bboxes.extend(refine_rows(row_bboxes)) + all_dis = [] + ious = [] + matched = {} + for i, gt_box in enumerate(gt_bboxes): + distances = [] + #temp_ious = [] + for j, pred_box in enumerate(pred_bboxes): + distances.append(distance(gt_box, pred_box)) + #temp_ious.append(compute_iou(gt_box, pred_box)) + #all_dis.append(min(distances)) + #ious.append(temp_ious[distances.index(min(distances))]) + if distances.index(min(distances)) not in matched.keys(): + matched[distances.index(min(distances))] = [i] + else: + matched[distances.index(min(distances))].append(i) + return matched#, sum(ious) / len(ious) + + + +#先挑选出一行,再进行匹配 +def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes): + gt_box_index = 0 + delete_gt_bboxes = gt_bboxes.copy() + match_bboxes_ready = [] + matched = {} + while(len(delete_gt_bboxes) != 0): + row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes) + row_bboxes = sorted(row_bboxes, key = lambda key: key[0]) + if len(pred_bboxes_rows) > 0: + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + print(row_bboxes) + for i, gt_box in enumerate(row_bboxes): + #print(gt_box) + pred_distances = [] + distances = [] + for pred_bbox in pred_bboxes: + pred_distances.append(distance(gt_box, pred_bbox)) + for j, pred_box in enumerate(match_bboxes_ready): + distances.append(distance(gt_box, pred_box)) + index = pred_distances.index(min(distances)) + #print('index', index) + if index not in matched.keys(): + matched[index] = [gt_box_index] + else: + matched[index].append(gt_box_index) + gt_box_index += 1 + return matched + +def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes): + ''' + gt_bboxes: 排序后 + pred_bboxes: + ''' + pre_bbox = gt_bboxes[0] + matched = {} + match_bboxes_ready = [] + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + for i, gt_box in enumerate(gt_bboxes): + + pred_distances = [] + for pred_bbox in pred_bboxes: + pred_distances.append(distance(gt_box, pred_bbox)) + distances = [] + gap_pre = gt_box[1] - pre_bbox[1] + gap_pre_1 = gt_box[0] - pre_bbox[2] + #print(gap_pre, len(pred_bboxes_rows)) + if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0): + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + if len(pred_bboxes_rows) == 1: + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0: + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0: + break + #print(match_bboxes_ready) + for j, pred_box in enumerate(match_bboxes_ready): + distances.append(distance(gt_box, pred_box)) + index = pred_distances.index(min(distances)) + #print(gt_box, index) + #match_bboxes_ready.pop(distances.index(min(distances))) + print(gt_box, match_bboxes_ready[distances.index(min(distances))]) + if index not in matched.keys(): + matched[index] = [i] + else: + matched[index].append(i) + pre_bbox = gt_box + return matched diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py new file mode 100755 index 0000000000000000000000000000000000000000..6e680b3574ba28b439acad34424b51dfdc02078c --- /dev/null +++ b/ppstructure/table/predict_structure.py @@ -0,0 +1,141 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import math +import time +import traceback +import paddle + +import tools.infer.utility as utility +from ppocr.data import create_operators, transform +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif + +logger = get_logger() + + +class TableStructurer(object): + def __init__(self, args): + pre_process_list = [{ + 'ResizeTableImage': { + 'max_len': args.structure_max_len + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'PaddingTableImage': None + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image'] + } + }] + postprocess_params = { + 'name': 'TableLabelDecode', + "character_type": args.structure_char_type, + "character_dict_path": args.structure_char_dict_path, + "max_text_length": args.structure_max_text_length, + "max_elem_length": args.structure_max_elem_length, + "max_cell_num": args.structure_max_cell_num + } + + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors = \ + utility.create_predictor(args, 'structure', logger) + + def __call__(self, img): + ori_im = img.copy() + data = {'image': img} + data = transform(data, self.preprocess_op) + img = data[0] + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + img = img.copy() + starttime = time.time() + + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + + preds = {} + preds['structure_probs'] = outputs[1] + preds['loc_preds'] = outputs[0] + + post_result = self.postprocess_op(preds) + + structure_str_list = post_result['structure_str_list'] + res_loc = post_result['res_loc'] + imgh, imgw = ori_im.shape[0:2] + res_loc_final = [] + for rno in range(len(res_loc[0])): + x0, y0, x1, y1 = res_loc[0][rno] + left = max(int(imgw * x0), 0) + top = max(int(imgh * y0), 0) + right = min(int(imgw * x1), imgw - 1) + bottom = min(int(imgh * y1), imgh - 1) + res_loc_final.append([left, top, right, bottom]) + + structure_str_list = structure_str_list[0][:-1] + structure_str_list = ['', '', ''] + structure_str_list + ['
', '', ''] + + elapse = time.time() - starttime + return (structure_str_list, res_loc_final), elapse + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + table_structurer = TableStructurer(args) + count = 0 + total_time = 0 + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + structure_res, elapse = table_structurer(img) + + logger.info("result: {}".format(structure_res)) + + if count > 0: + total_time += elapse + count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) + + +if __name__ == "__main__": + main(utility.parse_args()) diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py new file mode 100644 index 0000000000000000000000000000000000000000..c4edd22c3de4df5f0ba3e0a1e28a8c346a48d4ee --- /dev/null +++ b/ppstructure/table/predict_table.py @@ -0,0 +1,221 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import subprocess + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' +import cv2 +import copy +import numpy as np +import time +import tools.infer.predict_rec as predict_rec +import tools.infer.predict_det as predict_det +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.utils.logging import get_logger +from ppstructure.table.matcher import distance, compute_iou +from ppstructure.utility import parse_args +import ppstructure.table.predict_structure as predict_strture + +logger = get_logger() + + +def expand(pix, det_box, shape): + x0, y0, x1, y1 = det_box + # print(shape) + h, w, c = shape + tmp_x0 = x0 - pix + tmp_x1 = x1 + pix + tmp_y0 = y0 - pix + tmp_y1 = y1 + pix + x0_ = tmp_x0 if tmp_x0 >= 0 else 0 + x1_ = tmp_x1 if tmp_x1 <= w else w + y0_ = tmp_y0 if tmp_y0 >= 0 else 0 + y1_ = tmp_y1 if tmp_y1 <= h else h + return x0_, y0_, x1_, y1_ + + +class TableSystem(object): + def __init__(self, args, text_detector=None, text_recognizer=None): + self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector + self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer + self.table_structurer = predict_strture.TableStructurer(args) + + def __call__(self, img): + ori_im = img.copy() + structure_res, elapse = self.table_structurer(copy.deepcopy(img)) + dt_boxes, elapse = self.text_detector(copy.deepcopy(img)) + dt_boxes = sorted_boxes(dt_boxes) + + r_boxes = [] + for box in dt_boxes: + x_min = box[:, 0].min() - 1 + x_max = box[:, 0].max() + 1 + y_min = box[:, 1].min() - 1 + y_max = box[:, 1].max() + 1 + box = [x_min, y_min, x_max, y_max] + r_boxes.append(box) + dt_boxes = np.array(r_boxes) + + logger.debug("dt_boxes num : {}, elapse : {}".format( + len(dt_boxes), elapse)) + if dt_boxes is None: + return None, None + img_crop_list = [] + + for i in range(len(dt_boxes)): + det_box = dt_boxes[i] + x0, y0, x1, y1 = expand(2, det_box, ori_im.shape) + text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :] + img_crop_list.append(text_rect) + rec_res, elapse = self.text_recognizer(img_crop_list) + logger.debug("rec_res num : {}, elapse : {}".format( + len(rec_res), elapse)) + + pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res) + return pred_html + + def rebuild_table(self, structure_res, dt_boxes, rec_res): + pred_structures, pred_bboxes = structure_res + matched_index = self.match_result(dt_boxes, pred_bboxes) + pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) + return pred_html, pred + + def match_result(self, dt_boxes, pred_bboxes): + matched = {} + for i, gt_box in enumerate(dt_boxes): + # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])] + distances = [] + for j, pred_box in enumerate(pred_bboxes): + distances.append( + (distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) # 获取两两cell之间的L1距离和 1- IOU + sorted_distances = distances.copy() + # 根据距离和IOU挑选最"近"的cell + sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0])) + if distances.index(sorted_distances[0]) not in matched.keys(): + matched[distances.index(sorted_distances[0])] = [i] + else: + matched[distances.index(sorted_distances[0])].append(i) + return matched + + def get_pred_html(self, pred_structures, matched_index, ocr_contents): + end_html = [] + td_index = 0 + for tag in pred_structures: + if '' in tag: + if td_index in matched_index.keys(): + b_with = False + if '' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1: + b_with = True + end_html.extend('') + for i, td_index_index in enumerate(matched_index[td_index]): + content = ocr_contents[td_index_index][0] + if len(matched_index[td_index]) > 1: + if len(content) == 0: + continue + if content[0] == ' ': + content = content[1:] + if '' in content: + content = content[3:] + if '' in content: + content = content[:-4] + if len(content) == 0: + continue + if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]: + content += ' ' + end_html.extend(content) + if b_with: + end_html.extend('') + + end_html.append(tag) + td_index += 1 + else: + end_html.append(tag) + return ''.join(end_html), end_html + + +def sorted_boxes(dt_boxes): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + + +def to_excel(html_table, excel_path): + from tablepyxl import tablepyxl + tablepyxl.document_to_xl(html_table, excel_path) + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + image_file_list = image_file_list[args.process_id::args.total_process_num] + os.makedirs(args.output, exist_ok=True) + + text_sys = TableSystem(args) + img_num = len(image_file_list) + for i, image_file in enumerate(image_file_list): + logger.info("[{}/{}] {}".format(i, img_num, image_file)) + img, flag = check_and_read_gif(image_file) + excel_path = os.path.join(args.table_output, os.path.basename(image_file).split('.')[0] + '.xlsx') + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.error("error in loading image:{}".format(image_file)) + continue + starttime = time.time() + pred_html = text_sys(img) + + to_excel(pred_html, excel_path) + logger.info('excel saved to {}'.format(excel_path)) + logger.info(pred_html) + elapse = time.time() - starttime + logger.info("Predict time : {:.3f}s".format(elapse)) + + +if __name__ == "__main__": + args = parse_args() + if args.use_mp: + p_list = [] + total_process_num = args.total_process_num + for process_id in range(total_process_num): + cmd = [sys.executable, "-u"] + sys.argv + [ + "--process_id={}".format(process_id), + "--use_mp={}".format(False) + ] + p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout) + p_list.append(p) + for p in p_list: + p.wait() + else: + main(args) diff --git a/ppstructure/table/table_metric/__init__.py b/ppstructure/table/table_metric/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..de2d307430f68881ece1e41357d3b2f423e07ddd --- /dev/null +++ b/ppstructure/table/table_metric/__init__.py @@ -0,0 +1,16 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['TEDS'] +from .table_metric import TEDS \ No newline at end of file diff --git a/ppstructure/table/table_metric/parallel.py b/ppstructure/table/table_metric/parallel.py new file mode 100755 index 0000000000000000000000000000000000000000..f7326a1f506ca5fb7b3e97b0d077dc016e7eb7c7 --- /dev/null +++ b/ppstructure/table/table_metric/parallel.py @@ -0,0 +1,51 @@ +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor, as_completed + + +def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0): + """ + A parallel version of the map function with a progress bar. + Args: + array (array-like): An array to iterate over. + function (function): A python function to apply to the elements of array + n_jobs (int, default=16): The number of cores to use + use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of + keyword arguments to function + front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job. + Useful for catching bugs + Returns: + [function(array[0]), function(array[1]), ...] + """ + # We run the first few iterations serially to catch bugs + if front_num > 0: + front = [function(**a) if use_kwargs else function(a) + for a in array[:front_num]] + else: + front = [] + # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. + if n_jobs == 1: + return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])] + # Assemble the workers + with ProcessPoolExecutor(max_workers=n_jobs) as pool: + # Pass the elements of array into function + if use_kwargs: + futures = [pool.submit(function, **a) for a in array[front_num:]] + else: + futures = [pool.submit(function, a) for a in array[front_num:]] + kwargs = { + 'total': len(futures), + 'unit': 'it', + 'unit_scale': True, + 'leave': True + } + # Print out the progress as tasks complete + for f in tqdm(as_completed(futures), **kwargs): + pass + out = [] + # Get the results from the futures. + for i, future in tqdm(enumerate(futures)): + try: + out.append(future.result()) + except Exception as e: + out.append(e) + return front + out diff --git a/ppstructure/table/table_metric/table_metric.py b/ppstructure/table/table_metric/table_metric.py new file mode 100755 index 0000000000000000000000000000000000000000..9aca98ad785d4614a803fa5a277a6e4a27b3b078 --- /dev/null +++ b/ppstructure/table/table_metric/table_metric.py @@ -0,0 +1,247 @@ +# Copyright 2020 IBM +# Author: peter.zhong@au1.ibm.com +# +# This is free software; you can redistribute it and/or modify +# it under the terms of the Apache 2.0 License. +# +# This software is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Apache 2.0 License for more details. + +import distance +from apted import APTED, Config +from apted.helpers import Tree +from lxml import etree, html +from collections import deque +from .parallel import parallel_process +from tqdm import tqdm + + +class TableTree(Tree): + def __init__(self, tag, colspan=None, rowspan=None, content=None, *children): + self.tag = tag + self.colspan = colspan + self.rowspan = rowspan + self.content = content + self.children = list(children) + + def bracket(self): + """Show tree using brackets notation""" + if self.tag == 'td': + result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \ + (self.tag, self.colspan, self.rowspan, self.content) + else: + result = '"tag": %s' % self.tag + for child in self.children: + result += child.bracket() + return "{{{}}}".format(result) + + +class CustomConfig(Config): + @staticmethod + def maximum(*sequences): + """Get maximum possible value + """ + return max(map(len, sequences)) + + def normalized_distance(self, *sequences): + """Get distance from 0 to 1 + """ + return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) + + def rename(self, node1, node2): + """Compares attributes of trees""" + #print(node1.tag) + if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): + return 1. + if node1.tag == 'td': + if node1.content or node2.content: + #print(node1.content, ) + return self.normalized_distance(node1.content, node2.content) + return 0. + + + +class CustomConfig_del_short(Config): + @staticmethod + def maximum(*sequences): + """Get maximum possible value + """ + return max(map(len, sequences)) + + def normalized_distance(self, *sequences): + """Get distance from 0 to 1 + """ + return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) + + def rename(self, node1, node2): + """Compares attributes of trees""" + if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): + return 1. + if node1.tag == 'td': + if node1.content or node2.content: + #print('before') + #print(node1.content, node2.content) + #print('after') + node1_content = node1.content + node2_content = node2.content + if len(node1_content) < 3: + node1_content = ['####'] + if len(node2_content) < 3: + node2_content = ['####'] + return self.normalized_distance(node1_content, node2_content) + return 0. + +class CustomConfig_del_block(Config): + @staticmethod + def maximum(*sequences): + """Get maximum possible value + """ + return max(map(len, sequences)) + + def normalized_distance(self, *sequences): + """Get distance from 0 to 1 + """ + return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) + + def rename(self, node1, node2): + """Compares attributes of trees""" + if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): + return 1. + if node1.tag == 'td': + if node1.content or node2.content: + + node1_content = node1.content + node2_content = node2.content + while ' ' in node1_content: + print(node1_content.index(' ')) + node1_content.pop(node1_content.index(' ')) + while ' ' in node2_content: + print(node2_content.index(' ')) + node2_content.pop(node2_content.index(' ')) + return self.normalized_distance(node1_content, node2_content) + return 0. + +class TEDS(object): + ''' Tree Edit Distance basead Similarity + ''' + + def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None): + assert isinstance(n_jobs, int) and ( + n_jobs >= 1), 'n_jobs must be an integer greather than 1' + self.structure_only = structure_only + self.n_jobs = n_jobs + self.ignore_nodes = ignore_nodes + self.__tokens__ = [] + + def tokenize(self, node): + ''' Tokenizes table cells + ''' + self.__tokens__.append('<%s>' % node.tag) + if node.text is not None: + self.__tokens__ += list(node.text) + for n in node.getchildren(): + self.tokenize(n) + if node.tag != 'unk': + self.__tokens__.append('' % node.tag) + if node.tag != 'td' and node.tail is not None: + self.__tokens__ += list(node.tail) + + def load_html_tree(self, node, parent=None): + ''' Converts HTML tree to the format required by apted + ''' + global __tokens__ + if node.tag == 'td': + if self.structure_only: + cell = [] + else: + self.__tokens__ = [] + self.tokenize(node) + cell = self.__tokens__[1:-1].copy() + new_node = TableTree(node.tag, + int(node.attrib.get('colspan', '1')), + int(node.attrib.get('rowspan', '1')), + cell, *deque()) + else: + new_node = TableTree(node.tag, None, None, None, *deque()) + if parent is not None: + parent.children.append(new_node) + if node.tag != 'td': + for n in node.getchildren(): + self.load_html_tree(n, new_node) + if parent is None: + return new_node + + def evaluate(self, pred, true): + ''' Computes TEDS score between the prediction and the ground truth of a + given sample + ''' + if (not pred) or (not true): + return 0.0 + parser = html.HTMLParser(remove_comments=True, encoding='utf-8') + pred = html.fromstring(pred, parser=parser) + true = html.fromstring(true, parser=parser) + if pred.xpath('body/table') and true.xpath('body/table'): + pred = pred.xpath('body/table')[0] + true = true.xpath('body/table')[0] + if self.ignore_nodes: + etree.strip_tags(pred, *self.ignore_nodes) + etree.strip_tags(true, *self.ignore_nodes) + n_nodes_pred = len(pred.xpath(".//*")) + n_nodes_true = len(true.xpath(".//*")) + n_nodes = max(n_nodes_pred, n_nodes_true) + tree_pred = self.load_html_tree(pred) + tree_true = self.load_html_tree(true) + distance = APTED(tree_pred, tree_true, + CustomConfig()).compute_edit_distance() + return 1.0 - (float(distance) / n_nodes) + else: + return 0.0 + + def batch_evaluate(self, pred_json, true_json): + ''' Computes TEDS score between the prediction and the ground truth of + a batch of samples + @params pred_json: {'FILENAME': 'HTML CODE', ...} + @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...} + @output: {'FILENAME': 'TEDS SCORE', ...} + ''' + samples = true_json.keys() + if self.n_jobs == 1: + scores = [self.evaluate(pred_json.get( + filename, ''), true_json[filename]['html']) for filename in tqdm(samples)] + else: + inputs = [{'pred': pred_json.get( + filename, ''), 'true': true_json[filename]['html']} for filename in samples] + scores = parallel_process( + inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1) + scores = dict(zip(samples, scores)) + return scores + + def batch_evaluate_html(self, pred_htmls, true_htmls): + ''' Computes TEDS score between the prediction and the ground truth of + a batch of samples + ''' + if self.n_jobs == 1: + scores = [self.evaluate(pred_html, true_html) for ( + pred_html, true_html) in zip(pred_htmls, true_htmls)] + else: + inputs = [{"pred": pred_html, "true": true_html} for( + pred_html, true_html) in zip(pred_htmls, true_htmls)] + + scores = parallel_process( + inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1) + return scores + + +if __name__ == '__main__': + import json + import pprint + with open('sample_pred.json') as fp: + pred_json = json.load(fp) + with open('sample_gt.json') as fp: + true_json = json.load(fp) + teds = TEDS(n_jobs=4) + scores = teds.batch_evaluate(pred_json, true_json) + pp = pprint.PrettyPrinter() + pp.pprint(scores) diff --git a/ppstructure/table/tablepyxl/__init__.py b/ppstructure/table/tablepyxl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0085071cf4497b01fc648e7c38f2e8d9d173d0 --- /dev/null +++ b/ppstructure/table/tablepyxl/__init__.py @@ -0,0 +1,13 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/ppstructure/table/tablepyxl/style.py b/ppstructure/table/tablepyxl/style.py new file mode 100644 index 0000000000000000000000000000000000000000..ebd794b1b47d7f9e4f9294dde7330f592d613656 --- /dev/null +++ b/ppstructure/table/tablepyxl/style.py @@ -0,0 +1,283 @@ +# This is where we handle translating css styles into openpyxl styles +# and cascading those from parent to child in the dom. + +from openpyxl.cell import cell +from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color +from openpyxl.styles.fills import FILL_SOLID +from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE +from openpyxl.styles.colors import BLACK + +FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy' + + +def colormap(color): + """ + Convenience for looking up known colors + """ + cmap = {'black': BLACK} + return cmap.get(color, color) + + +def style_string_to_dict(style): + """ + Convert css style string to a python dictionary + """ + def clean_split(string, delim): + return (s.strip() for s in string.split(delim)) + styles = [clean_split(s, ":") for s in style.split(";") if ":" in s] + return dict(styles) + + +def get_side(style, name): + return {'border_style': style.get('border-{}-style'.format(name)), + 'color': colormap(style.get('border-{}-color'.format(name)))} + +known_styles = {} + + +def style_dict_to_named_style(style_dict, number_format=None): + """ + Change css style (stored in a python dictionary) to openpyxl NamedStyle + """ + + style_and_format_string = str({ + 'style_dict': style_dict, + 'parent': style_dict.parent, + 'number_format': number_format, + }) + + if style_and_format_string not in known_styles: + # Font + font = Font(bold=style_dict.get('font-weight') == 'bold', + color=style_dict.get_color('color', None), + size=style_dict.get('font-size')) + + # Alignment + alignment = Alignment(horizontal=style_dict.get('text-align', 'general'), + vertical=style_dict.get('vertical-align'), + wrap_text=style_dict.get('white-space', 'nowrap') == 'normal') + + # Fill + bg_color = style_dict.get_color('background-color') + fg_color = style_dict.get_color('foreground-color', Color()) + fill_type = style_dict.get('fill-type') + if bg_color and bg_color != 'transparent': + fill = PatternFill(fill_type=fill_type or FILL_SOLID, + start_color=bg_color, + end_color=fg_color) + else: + fill = PatternFill() + + # Border + border = Border(left=Side(**get_side(style_dict, 'left')), + right=Side(**get_side(style_dict, 'right')), + top=Side(**get_side(style_dict, 'top')), + bottom=Side(**get_side(style_dict, 'bottom')), + diagonal=Side(**get_side(style_dict, 'diagonal')), + diagonal_direction=None, + outline=Side(**get_side(style_dict, 'outline')), + vertical=None, + horizontal=None) + + name = 'Style {}'.format(len(known_styles) + 1) + + pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border, + number_format=number_format) + + known_styles[style_and_format_string] = pyxl_style + + return known_styles[style_and_format_string] + + +class StyleDict(dict): + """ + It's like a dictionary, but it looks for items in the parent dictionary + """ + def __init__(self, *args, **kwargs): + self.parent = kwargs.pop('parent', None) + super(StyleDict, self).__init__(*args, **kwargs) + + def __getitem__(self, item): + if item in self: + return super(StyleDict, self).__getitem__(item) + elif self.parent: + return self.parent[item] + else: + raise KeyError('{} not found'.format(item)) + + def __hash__(self): + return hash(tuple([(k, self.get(k)) for k in self._keys()])) + + # Yielding the keys avoids creating unnecessary data structures + # and happily works with both python2 and python3 where the + # .keys() method is a dictionary_view in python3 and a list in python2. + def _keys(self): + yielded = set() + for k in self.keys(): + yielded.add(k) + yield k + if self.parent: + for k in self.parent._keys(): + if k not in yielded: + yielded.add(k) + yield k + + def get(self, k, d=None): + try: + return self[k] + except KeyError: + return d + + def get_color(self, k, d=None): + """ + Strip leading # off colors if necessary + """ + color = self.get(k, d) + if hasattr(color, 'startswith') and color.startswith('#'): + color = color[1:] + if len(color) == 3: # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that + color = ''.join(2 * c for c in color) + return color + + +class Element(object): + """ + Our base class for representing an html element along with a cascading style. + The element is created along with a parent so that the StyleDict that we store + can point to the parent's StyleDict. + """ + def __init__(self, element, parent=None): + self.element = element + self.number_format = None + parent_style = parent.style_dict if parent else None + self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style) + self._style_cache = None + + def style(self): + """ + Turn the css styles for this element into an openpyxl NamedStyle. + """ + if not self._style_cache: + self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format) + return self._style_cache + + def get_dimension(self, dimension_key): + """ + Extracts the dimension from the style dict of the Element and returns it as a float. + """ + dimension = self.style_dict.get(dimension_key) + if dimension: + if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']: + dimension = dimension[:-2] + dimension = float(dimension) + return dimension + + +class Table(Element): + """ + The concrete implementations of Elements are semantically named for the types of elements we are interested in. + This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to + allowing Element to have an arbitrary number of children and dealing with an abstract element tree. + """ + def __init__(self, table): + """ + takes an html table object (from lxml) + """ + super(Table, self).__init__(table) + table_head = table.find('thead') + self.head = TableHead(table_head, parent=self) if table_head is not None else None + table_body = table.find('tbody') + self.body = TableBody(table_body if table_body is not None else table, parent=self) + + +class TableHead(Element): + """ + This class maps to the `` element of the html table. + """ + def __init__(self, head, parent=None): + super(TableHead, self).__init__(head, parent=parent) + self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')] + + +class TableBody(Element): + """ + This class maps to the `` element of the html table. + """ + def __init__(self, body, parent=None): + super(TableBody, self).__init__(body, parent=parent) + self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')] + + +class TableRow(Element): + """ + This class maps to the `` element of the html table. + """ + def __init__(self, tr, parent=None): + super(TableRow, self).__init__(tr, parent=parent) + self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')] + + +def element_to_string(el): + return _element_to_string(el).strip() + + +def _element_to_string(el): + string = '' + + for x in el.iterchildren(): + string += '\n' + _element_to_string(x) + + text = el.text.strip() if el.text else '' + tail = el.tail.strip() if el.tail else '' + + return text + string + '\n' + tail + + +class TableCell(Element): + """ + This class maps to the `` element of the html table. + """ + CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE', + 'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'} + + def __init__(self, cell, parent=None): + super(TableCell, self).__init__(cell, parent=parent) + self.value = element_to_string(cell) + self.number_format = self.get_number_format() + + def data_type(self): + cell_types = self.CELL_TYPES & set(self.element.get('class', '').split()) + if cell_types: + if 'TYPE_FORMULA' in cell_types: + # Make sure TYPE_FORMULA takes precedence over the other classes in the set. + cell_type = 'TYPE_FORMULA' + elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}: + cell_type = 'TYPE_NUMERIC' + else: + cell_type = cell_types.pop() + else: + cell_type = 'TYPE_STRING' + return getattr(cell, cell_type) + + def get_number_format(self): + if 'TYPE_CURRENCY' in self.element.get('class', '').split(): + return FORMAT_CURRENCY_USD_SIMPLE + if 'TYPE_INTEGER' in self.element.get('class', '').split(): + return '#,##0' + if 'TYPE_PERCENTAGE' in self.element.get('class', '').split(): + return FORMAT_PERCENTAGE + if 'TYPE_DATE' in self.element.get('class', '').split(): + return FORMAT_DATE_MMDDYYYY + if self.data_type() == cell.TYPE_NUMERIC: + try: + int(self.value) + except ValueError: + return '#,##0.##' + else: + return '#,##0' + + def format(self, cell): + cell.style = self.style() + data_type = self.data_type() + if data_type: + cell.data_type = data_type \ No newline at end of file diff --git a/ppstructure/table/tablepyxl/tablepyxl.py b/ppstructure/table/tablepyxl/tablepyxl.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3cc0fc499fccd93ffe3993a99296bc6603ed8a --- /dev/null +++ b/ppstructure/table/tablepyxl/tablepyxl.py @@ -0,0 +1,118 @@ +# Do imports like python3 so our package works for 2 and 3 +from __future__ import absolute_import + +from lxml import html +from openpyxl import Workbook +from openpyxl.utils import get_column_letter +from premailer import Premailer +from tablepyxl.style import Table + + +def string_to_int(s): + if s.isdigit(): + return int(s) + return 0 + + +def get_Tables(doc): + tree = html.fromstring(doc) + comments = tree.xpath('//comment()') + for comment in comments: + comment.drop_tag() + return [Table(table) for table in tree.xpath('//table')] + + +def write_rows(worksheet, elem, row, column=1): + """ + Writes every tr child element of elem to a row in the worksheet + returns the next row after all rows are written + """ + from openpyxl.cell.cell import MergedCell + + initial_column = column + for table_row in elem.rows: + for table_cell in table_row.cells: + cell = worksheet.cell(row=row, column=column) + while isinstance(cell, MergedCell): + column += 1 + cell = worksheet.cell(row=row, column=column) + + colspan = string_to_int(table_cell.element.get("colspan", "1")) + rowspan = string_to_int(table_cell.element.get("rowspan", "1")) + if rowspan > 1 or colspan > 1: + worksheet.merge_cells(start_row=row, start_column=column, + end_row=row + rowspan - 1, end_column=column + colspan - 1) + + cell.value = table_cell.value + table_cell.format(cell) + min_width = table_cell.get_dimension('min-width') + max_width = table_cell.get_dimension('max-width') + + if colspan == 1: + # Initially, when iterating for the first time through the loop, the width of all the cells is None. + # As we start filling in contents, the initial width of the cell (which can be retrieved by: + # worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous + # cell in the same column (i.e. width of A2 = width of A1) + width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2) + if max_width and width > max_width: + width = max_width + elif min_width and width < min_width: + width = min_width + worksheet.column_dimensions[get_column_letter(column)].width = width + column += colspan + row += 1 + column = initial_column + return row + + +def table_to_sheet(table, wb): + """ + Takes a table and workbook and writes the table to a new sheet. + The sheet title will be the same as the table attribute name. + """ + ws = wb.create_sheet(title=table.element.get('name')) + insert_table(table, ws, 1, 1) + + +def document_to_workbook(doc, wb=None, base_url=None): + """ + Takes a string representation of an html document and writes one sheet for + every table in the document. + The workbook is returned + """ + if not wb: + wb = Workbook() + wb.remove(wb.active) + + inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform() + tables = get_Tables(inline_styles_doc) + + for table in tables: + table_to_sheet(table, wb) + + return wb + + +def document_to_xl(doc, filename, base_url=None): + """ + Takes a string representation of an html document and writes one sheet for + every table in the document. The workbook is written out to a file called filename + """ + wb = document_to_workbook(doc, base_url=base_url) + wb.save(filename) + + +def insert_table(table, worksheet, column, row): + if table.head: + row = write_rows(worksheet, table.head, row, column) + if table.body: + row = write_rows(worksheet, table.body, row, column) + + +def insert_table_at_cell(table, cell): + """ + Inserts a table at the location of an openpyxl Cell object. + """ + ws = cell.parent + column, row = cell.column, cell.row + insert_table(table, ws, column, row) \ No newline at end of file diff --git a/ppstructure/utility.py b/ppstructure/utility.py new file mode 100644 index 0000000000000000000000000000000000000000..8112b9efd2155d69784ebc9915d9c3ec30e94f9c --- /dev/null +++ b/ppstructure/utility.py @@ -0,0 +1,59 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PIL import Image +import numpy as np +from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args + + +def init_args(): + parser = infer_args() + + # params for output + parser.add_argument("--output", type=str, default='./output/table') + # params for table structure + parser.add_argument("--structure_max_len", type=int, default=488) + parser.add_argument("--structure_max_text_length", type=int, default=100) + parser.add_argument("--structure_max_elem_length", type=int, default=800) + parser.add_argument("--structure_max_cell_num", type=int, default=500) + parser.add_argument("--structure_model_dir", type=str) + parser.add_argument("--structure_char_type", type=str, default='en') + parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt") + + # params for layout detector + parser.add_argument("--layout_model_dir", type=str) + return parser + + +def parse_args(): + parser = init_args() + return parser.parse_args() + + +def draw_result(image, result, font_path): + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + boxes, txts, scores = [], [], [] + for region in result: + if region['type'] == 'Table': + pass + elif region['type'] == 'Figure': + pass + else: + for box, rec_res in zip(region['res'][0], region['res'][1]): + boxes.append(np.array(box).reshape(-1, 2)) + txts.append(rec_res[0]) + scores.append(rec_res[1]) + im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0) + return im_show \ No newline at end of file diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index f5bade36315fbe321927df82cdd7cd8bf40b2ae5..baa89be130084d98628656fe4e309728a0e9f661 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -43,7 +43,7 @@ class TextDetector(object): pre_process_list = [{ 'DetResizeForTest': { 'limit_side_len': args.det_limit_side_len, - 'limit_type': args.det_limit_type + 'limit_type': args.det_limit_type, } }, { 'NormalizeImage': { diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 8f1b8f4812a6ae47a2456b7da9d456befa374183..58a363e3e8fa852ce37cd5a44a19e460da00c2bc 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -88,6 +88,8 @@ class TextSystem(object): ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) + logger.debug("dt_boxes num : {}, elapse : {}".format( + len(dt_boxes), elapse)) if dt_boxes is None: return None, None img_crop_list = [] @@ -101,9 +103,13 @@ class TextSystem(object): if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list) + logger.debug("cls num : {}, elapse : {}".format( + len(img_crop_list), elapse)) rec_res, elapse = self.text_recognizer(img_crop_list) - + logger.debug("rec_res num : {}, elapse : {}".format( + len(rec_res), elapse)) + # self.print_draw_crop_rec_res(img_crop_list, rec_res) filter_boxes, filter_rec_res = [], [] for box, rec_reuslt in zip(dt_boxes, rec_res): text, score = rec_reuslt @@ -149,7 +155,7 @@ def main(args): if not flag: img = cv2.imread(image_file) if img is None: - logger.info("error in loading image:{}".format(image_file)) + logger.error("error in loading image:{}".format(image_file)) continue starttime = time.time() dt_boxes, rec_res = text_sys(img) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 20fcfe4bc7475822a5e382f8908b289b614a9fe1..9210c45783029996f8d6fd105c8413d01c768806 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -109,10 +109,12 @@ def init_args(): parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--process_id", type=int, default=0) - + parser.add_argument("--benchmark", type=bool, default=False) parser.add_argument("--save_log_path", type=str, default="./log_output/") + parser.add_argument("--show_log", type=str2bool, default=True) + return parser @@ -198,6 +200,8 @@ def create_predictor(args, mode, logger): model_dir = args.cls_model_dir elif mode == 'rec': model_dir = args.rec_model_dir + elif mode == 'structure': + model_dir = args.structure_model_dir else: model_dir = args.e2e_model_dir @@ -327,7 +331,9 @@ def create_predictor(args, mode, logger): config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.switch_use_feed_fetch_ops(False) - + config.switch_ir_optim(True) + if mode == 'structure': + config.switch_ir_optim(False) # create predictor predictor = inference.create_predictor(config) input_names = predictor.get_input_names()