diff --git a/MANIFEST.in b/MANIFEST.in
index e16f157d6e9dd249d6c6a14ae54313759a6752c4..cd1c9636d4d23cc4d0f745403ec8ca407d1cc1a8 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,7 +1,7 @@
-include LICENSE.txt
+include LICENSE
include README.md
-recursive-include ppocr/utils *.txt utility.py logging.py
+recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java
index 1c83e2184fe55aedd5022da839ab294b6bbe475c..b4ea34e2a38f91f3ecb1001c6bff3b71496b8f91 100644
--- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java
+++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java
@@ -465,8 +465,12 @@ public class MainActivity extends AppCompatActivity {
}
public void btn_load_model_click(View view) {
- tvStatus.setText("STATUS: load model ......");
- loadModel();
+ // Only call loadModel() when the predictor reports that no model is loaded;
+ // otherwise just update the status text.
+ if (predictor.isLoaded()){
+ tvStatus.setText("STATUS: model has been loaded");
+ }else{
+ tvStatus.setText("STATUS: load model ......");
+ loadModel();
+ }
}
public void btn_run_model_click(View view) {
diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java
index 1c294995c25b7eb3fa6ded17f41f193bddfc3886..b474d8886a10746b8ac181085c62481dfe7a4229 100644
--- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java
+++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java
@@ -194,26 +194,25 @@ public class Predictor {
"supported!");
return false;
}
- int[] channelStride = new int[]{width * height, width * height * 2};
- int p = scaleImage.getPixel(scaleImage.getWidth() - 1, scaleImage.getHeight() - 1);
- for (int y = 0; y < height; y++) {
- for (int x = 0; x < width; x++) {
- int color = scaleImage.getPixel(x, y);
- float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
- (float) blue(color) / 255.0f};
- inputData[y * width + x] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
- inputData[y * width + x + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
- inputData[y * width + x + channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
- }
+ int[] channelStride = new int[]{width * height, width * height * 2};
+ int[] pixels=new int[width*height];
+ scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight());
+ for (int i = 0; i < pixels.length; i++) {
+ int color = pixels[i];
+ float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
+ (float) blue(color) / 255.0f};
+ inputData[i] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
+ inputData[i + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
+ inputData[i+ channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
}
} else if (channels == 1) {
- for (int y = 0; y < height; y++) {
- for (int x = 0; x < width; x++) {
- int color = inputImage.getPixel(x, y);
- float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
- inputData[y * width + x] = (gray - inputMean[0]) / inputStd[0];
- }
+ int[] pixels=new int[width*height];
+ scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight());
+ for (int i = 0; i < pixels.length; i++) {
+ int color = pixels[i];
+ float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
+ inputData[i] = (gray - inputMean[0]) / inputStd[0];
}
} else {
Log.i(TAG, "Unsupported channel size " + Integer.toString(channels) + ", only channel 1 and 3 is " +
diff --git a/doc/joinus.PNG b/doc/joinus.PNG
index 6e299a2ebe0eb52aa799ba9fa924bd685cd248de..4a274e631d8516789fca47e2db66bc0ac2d0f223 100644
Binary files a/doc/joinus.PNG and b/doc/joinus.PNG differ
diff --git a/doc/table/1.png b/doc/table/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..47df618ab1bef431a5dd94418c01be16b09d31aa
Binary files /dev/null and b/doc/table/1.png differ
diff --git a/paddleocr.py b/paddleocr.py
index 1e4d94ff4e72da951e1ffb92edb50715482581ae..48c8c9c6523dc3f813189477e641f0e51b740885 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -19,17 +19,16 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
+import logging
import numpy as np
from pathlib import Path
-import tarfile
-import requests
-from tqdm import tqdm
from tools.infer import predict_system
from ppocr.utils.logging import get_logger
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
+from ppocr.utils.network import maybe_download, download_with_progressbar
from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR']
@@ -37,84 +36,84 @@ __all__ = ['PaddleOCR']
model_urls = {
'det': {
'ch':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'en':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
},
'rec': {
'ch': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
'french': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
},
'te': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
},
'ka': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
},
'latin': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
}
},
'cls':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
}
SUPPORT_DET_MODEL = ['DB']
@@ -123,50 +122,6 @@ SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
-def download_with_progressbar(url, save_path):
- response = requests.get(url, stream=True)
- total_size_in_bytes = int(response.headers.get('content-length', 0))
- block_size = 1024 # 1 Kibibyte
- progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
- with open(save_path, 'wb') as file:
- for data in response.iter_content(block_size):
- progress_bar.update(len(data))
- file.write(data)
- progress_bar.close()
- if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
- logger.error("Something went wrong while downloading models")
- sys.exit(0)
-
-
-def maybe_download(model_storage_directory, url):
- # using custom model
- tar_file_name_list = [
- 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
- ]
- if not os.path.exists(
- os.path.join(model_storage_directory, 'inference.pdiparams')
- ) or not os.path.exists(
- os.path.join(model_storage_directory, 'inference.pdmodel')):
- tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
- print('download {} to {}'.format(url, tmp_path))
- os.makedirs(model_storage_directory, exist_ok=True)
- download_with_progressbar(url, tmp_path)
- with tarfile.open(tmp_path, 'r') as tarObj:
- for member in tarObj.getmembers():
- filename = None
- for tar_file_name in tar_file_name_list:
- if tar_file_name in member.name:
- filename = tar_file_name
- if filename is None:
- continue
- file = tarObj.extractfile(member)
- with open(
- os.path.join(model_storage_directory, filename),
- 'wb') as f:
- f.write(file.read())
- os.remove(tmp_path)
-
-
def parse_args(mMain=True):
import argparse
parser = init_args()
@@ -194,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem):
args:
**kwargs: other params show in paddleocr --help
"""
- postprocess_params = parse_args(mMain=False)
- postprocess_params.__dict__.update(**kwargs)
- self.use_angle_cls = postprocess_params.use_angle_cls
- lang = postprocess_params.lang
+ params = parse_args(mMain=False)
+ params.__dict__.update(**kwargs)
+ if params.show_log:
+ logger.setLevel(logging.DEBUG)
+ self.use_angle_cls = params.use_angle_cls
+ lang = params.lang
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
@@ -223,46 +180,46 @@ class PaddleOCR(predict_system.TextSystem):
lang = "devanagari"
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
- model_urls['rec'].keys(), lang)
+ model_urls['rec'].keys(), lang)
if lang == "ch":
det_lang = "ch"
else:
det_lang = "en"
use_inner_dict = False
- if postprocess_params.rec_char_dict_path is None:
+ if params.rec_char_dict_path is None:
use_inner_dict = True
- postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
+ params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path']
# init model dir
- if postprocess_params.det_model_dir is None:
- postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
+ if params.det_model_dir is None:
+ params.det_model_dir = os.path.join(BASE_DIR, VERSION,
'det', det_lang)
- if postprocess_params.rec_model_dir is None:
- postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
+ if params.rec_model_dir is None:
+ params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
'rec', lang)
- if postprocess_params.cls_model_dir is None:
- postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
- print(postprocess_params)
+ if params.cls_model_dir is None:
+ params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
# download model
- maybe_download(postprocess_params.det_model_dir,
+ maybe_download(params.det_model_dir,
model_urls['det'][det_lang])
- maybe_download(postprocess_params.rec_model_dir,
+ maybe_download(params.rec_model_dir,
model_urls['rec'][lang]['url'])
- maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
+ maybe_download(params.cls_model_dir, model_urls['cls'])
- if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
+ if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0)
- if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
+ if params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0)
if use_inner_dict:
- postprocess_params.rec_char_dict_path = str(
- Path(__file__).parent / postprocess_params.rec_char_dict_path)
+ params.rec_char_dict_path = str(
+ Path(__file__).parent / params.rec_char_dict_path)
+ print(params)
# init det_model and rec_model
- super().__init__(postprocess_params)
+ super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=True):
"""
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 9c48b09647527cf718113ea1b5df152ff7befa04..2535b4420c503f2e9e9cc5a677ef70c4dd9c36be 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -81,7 +81,7 @@ class NormalizeImage(object):
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
- img.astype('float32') * self.scale - self.mean) / self.std
+ img.astype('float32') * self.scale - self.mean) / self.std
return data
@@ -163,7 +163,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
- h, w, _ = img.shape
+ h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
@@ -174,7 +174,7 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
- else:
+ elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
@@ -182,6 +182,10 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
+ elif self.limit_type == 'resize_long':
+ ratio = float(limit_side_len) / max(h,w)
+ else:
+ raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index d353391c9af2b85bd01ba659f541fa1791461f68..85ce580f95b13539c6aeea32b188bfd3b435d140 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
- self.character_str = ""
+ self.character_str = []
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
- self.character_str += line
+ self.character_str.append(line)
if use_space_char:
- self.character_str += " "
+ self.character_str.append(" ")
dict_character = list(self.character_str)
else:
@@ -288,3 +288,156 @@ class SRNLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class TableLabelDecode(object):
+ """ """
+
+ def __init__(self,
+ max_text_length,
+ max_elem_length,
+ max_cell_num,
+ character_dict_path,
+ **kwargs):
+ """Store the size limits and build the character and element lookup tables.
+
+ The file at ``character_dict_path`` is parsed by ``load_char_elem_dict``
+ into a character list and an element list. Each list is wrapped with the
+ "sos"/"eos" markers by ``add_special_char`` and then indexed in both
+ directions (token -> index and index -> token).
+
+ ``max_text_length``, ``max_elem_length`` and ``max_cell_num`` are only
+ stored on the instance here. Extra keyword arguments are accepted and
+ not used by this method.
+ """
+ self.max_text_length = max_text_length
+ self.max_elem_length = max_elem_length
+ self.max_cell_num = max_cell_num
+ list_character, list_elem = self.load_char_elem_dict(character_dict_path)
+ list_character = self.add_special_char(list_character)
+ list_elem = self.add_special_char(list_elem)
+ # Two-way mapping for characters: index -> character and character -> index.
+ self.dict_character = {}
+ self.dict_idx_character = {}
+ for i, char in enumerate(list_character):
+ self.dict_idx_character[i] = char
+ self.dict_character[char] = i
+ # Same two-way mapping for structure elements.
+ self.dict_elem = {}
+ self.dict_idx_elem = {}
+ for i, elem in enumerate(list_elem):
+ self.dict_idx_elem[i] = elem
+ self.dict_elem[elem] = i
+
+ def load_char_elem_dict(self, character_dict_path):
+ """Read the dictionary file and split it into character and element lists.
+
+ The first line holds two tab-separated integers: the number of
+ character entries and the number of element entries. The character
+ entries follow on the next lines, then the element entries. Any lines
+ after those two sections are not read.
+
+ Returns:
+ (list_character, list_elem), both lists of str.
+ """
+ list_character = []
+ list_elem = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ # Header line: "<character_num>\t<elem_num>".
+ substr = lines[0].decode('utf-8').strip("\n").split("\t")
+ character_num = int(substr[0])
+ elem_num = int(substr[1])
+ # NOTE(review): strip("\n") removes only newline characters, so a file
+ # saved with CRLF line endings would keep a trailing "\r" on every
+ # entry -- confirm the dictionary is always LF-terminated.
+ for cno in range(1, 1 + character_num):
+ character = lines[cno].decode('utf-8').strip("\n")
+ list_character.append(character)
+ for eno in range(1 + character_num, 1 + character_num + elem_num):
+ elem = lines[eno].decode('utf-8').strip("\n")
+ list_elem.append(elem)
+ return list_character, list_elem
+
+ def add_special_char(self, list_character):
+ """Return a new list with "sos" prepended and "eos" appended.
+
+ Also sets ``self.beg_str`` and ``self.end_str`` to those two marker
+ strings on every call.
+ """
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ list_character = [self.beg_str] + list_character + [self.end_str]
+ return list_character
+
+ def get_sp_tokens(self):
+ char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
+ char_end_idx = self.get_beg_end_flag_idx('end', 'char')
+ elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
+ elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
+ elem_char_idx1 = self.dict_elem['
']
+ elem_char_idx2 = self.dict_elem[' | ', ' | 0 and tmp_elem_idx == end_idx:
+ break
+ if tmp_elem_idx in ignored_tokens:
+ continue
+
+ char_list.append(current_dict[tmp_elem_idx])
+ elem_pos_list.append(idx)
+ score_list.append(structure_probs[batch_idx, idx])
+ elem_idx_list.append(tmp_elem_idx)
+ result_list.append(char_list)
+ result_pos_list.append(elem_pos_list)
+ result_score_list.append(score_list)
+ result_elem_idx_list.append(elem_idx_list)
+ return result_list, result_pos_list, result_score_list, result_elem_idx_list
+
+ def get_ignored_tokens(self, char_or_elem):
+ """Return ``[beg_idx, end_idx]`` for the chosen dictionary.
+
+ ``char_or_elem`` is passed straight to ``get_beg_end_flag_idx``, which
+ accepts "char" or "elem".
+ """
+ beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
+ end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
+ """Look up the index of the begin or end marker in one of the dictionaries.
+
+ Args:
+ beg_or_end: "beg" selects ``self.beg_str``, "end" selects ``self.end_str``.
+ char_or_elem: "char" looks in ``self.dict_character``, "elem" looks in
+ ``self.dict_elem``.
+
+ Returns:
+ The index stored for the selected marker.
+
+ Any other value for either argument reaches an ``assert False``.
+ """
+ # NOTE(review): validation relies on assert statements, which are removed
+ # when Python runs with -O; an invalid argument would then reach
+ # `return idx` with idx unassigned -- confirm whether a raise is wanted.
+ if char_or_elem == "char":
+ if beg_or_end == "beg":
+ idx = self.dict_character[self.beg_str]
+ elif beg_or_end == "end":
+ idx = self.dict_character[self.end_str]
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
+ % beg_or_end
+ elif char_or_elem == "elem":
+ if beg_or_end == "beg":
+ idx = self.dict_elem[self.beg_str]
+ elif beg_or_end == "end":
+ idx = self.dict_elem[self.end_str]
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
+ % beg_or_end
+ else:
+ assert False, "Unsupport type %s in char_or_elem" \
+ % char_or_elem
+ return idx
diff --git a/ppocr/utils/dict/table_dict.txt b/ppocr/utils/dict/table_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2ef028c786cbce6d1e25856c62986d757b31f93b
--- /dev/null
+++ b/ppocr/utils/dict/table_dict.txt
@@ -0,0 +1,277 @@
+←
+
+☆
+─
+α
+
+
+⋅
+$
+ω
+ψ
+χ
+(
+υ
+≥
+σ
+,
+ρ
+ε
+0
+■
+4
+8
+✗
+b
+<
+✓
+Ψ
+Ω
+€
+D
+3
+Π
+H
+║
+
+L
+Φ
+Χ
+θ
+P
+κ
+λ
+μ
+T
+ξ
+X
+β
+γ
+δ
+\
+ζ
+η
+`
+d
+
+h
+f
+l
+Θ
+p
+√
+t
+
+x
+Β
+Γ
+Δ
+|
+ǂ
+ɛ
+j
+̧
+➢
+
+̌
+′
+«
+△
+▲
+#
+
+'
+Ι
++
+¶
+/
+▼
+⇑
+□
+·
+7
+▪
+;
+?
+➔
+∩
+C
+÷
+G
+⇒
+K
+
+O
+S
+С
+W
+Α
+[
+○
+_
+●
+‡
+c
+z
+g
+
+o
+
+〈
+〉
+s
+⩽
+w
+φ
+ʹ
+{
+»
+∣
+̆
+e
+ˆ
+∈
+τ
+◆
+ι
+∅
+∆
+∙
+∘
+Ø
+ß
+✔
+∞
+∑
+−
+×
+◊
+∗
+∖
+˃
+˂
+∫
+"
+i
+&
+π
+↔
+*
+∥
+æ
+∧
+.
+⁄
+ø
+Q
+∼
+6
+⁎
+:
+★
+>
+a
+B
+≈
+F
+J
+̄
+N
+♯
+R
+V
+
+―
+Z
+♣
+^
+¤
+¥
+§
+
+¢
+£
+≦
+
+≤
+‖
+Λ
+©
+n
+↓
+→
+↑
+r
+°
+±
+v
+
+♂
+k
+♀
+~
+ᅟ
+̇
+@
+”
+♦
+ł
+®
+⊕
+„
+!
+
+%
+⇓
+)
+-
+1
+5
+9
+=
+А
+A
+‰
+⋆
+Σ
+E
+◦
+I
+※
+M
+m
+̨
+⩾
+†
+
+•
+U
+Y
+
+]
+̸
+2
+‐
+–
+‒
+̂
+—
+̀
+́
+’
+‘
+⋮
+⋯
+̊
+“
+̈
+≧
+q
+u
+ı
+y
+
+
+̃
+}
+ν
diff --git a/ppocr/utils/dict/table_structure_dict.txt b/ppocr/utils/dict/table_structure_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9c4531e5f3b8c498e70d3c2ea0471e5e746a2c30
--- /dev/null
+++ b/ppocr/utils/dict/table_structure_dict.txt
@@ -0,0 +1,2759 @@
+277 28 1267 1186
+
+V
+a
+r
+i
+b
+l
+e
+
+H
+z
+d
+
+t
+o
+9
+5
+%
+C
+I
+
+p
+
+v
+u
+*
+A
+g
+(
+m
+n
+)
+0
+.
+7
+1
+6
+≤
+>
+8
+3
+–
+2
+G
+4
+M
+F
+T
+y
+f
+s
+L
+w
+c
+U
+h
+D
+S
+Q
+R
+x
+P
+-
+E
+O
+/
+k
+,
++
+N
+K
+q
+′
+[
+]
+<
+≥
+
+−
+
+μ
+±
+J
+j
+W
+_
+Δ
+B
+“
+:
+Y
+α
+λ
+;
+
+
+?
+∼
+=
+°
+#
+̊
+̈
+̂
+’
+Z
+X
+∗
+—
+β
+'
+†
+~
+@
+"
+γ
+↓
+↑
+&
+‡
+χ
+”
+σ
+§
+|
+¶
+‐
+×
+$
+→
+√
+✓
+‘
+\
+∞
+π
+•
+®
+^
+∆
+≧
+
+
+́
+♀
+♂
+‒
+⁎
+▲
+·
+£
+φ
+Ψ
+ß
+△
+☆
+▪
+η
+€
+∧
+̃
+Φ
+ρ
+̄
+δ
+‰
+̧
+Ω
+♦
+{
+}
+̀
+∑
+∫
+ø
+κ
+ε
+¥
+※
+`
+ω
+Σ
+➔
+‖
+Β
+̸
+
+─
+●
+⩾
+Χ
+Α
+⋅
+◆
+★
+■
+ψ
+ǂ
+□
+ζ
+!
+Γ
+↔
+θ
+⁄
+〈
+〉
+―
+υ
+τ
+⋆
+Ø
+©
+∥
+С
+˂
+➢
+ɛ
+
+✗
+←
+○
+¢
+⩽
+∖
+˃
+
+≈
+Π
+̌
+≦
+∅
+ᅟ
+
+
+∣
+¤
+♯
+̆
+ξ
+÷
+▼
+
+ι
+ν
+║
+
+
+◦
+
+◊
+∙
+«
+»
+ł
+ı
+Θ
+∈
+„
+∘
+✔
+̇
+æ
+ʹ
+ˆ
+♣
+⇓
+∩
+⊕
+⇒
+⇑
+̨
+Ι
+Λ
+⋯
+А
+⋮
+
+
+
+ |
+
+
+
+
+
+ colspan="2"
+ colspan="3"
+ rowspan="2"
+ colspan="4"
+ colspan="6"
+ rowspan="3"
+ colspan="9"
+ colspan="10"
+ colspan="7"
+ rowspan="4"
+ rowspan="5"
+ rowspan="9"
+ colspan="8"
+ rowspan="8"
+ rowspan="6"
+ rowspan="7"
+ rowspan="10"
+0 2924682
+1 3405345
+2 2363468
+3 2709165
+4 4078680
+5 3250792
+6 1923159
+7 1617890
+8 1450532
+9 1717624
+10 1477550
+11 1489223
+12 915528
+13 819193
+14 593660
+15 518924
+16 682065
+17 494584
+18 400591
+19 396421
+20 340994
+21 280688
+22 250328
+23 226786
+24 199927
+25 182707
+26 164629
+27 141613
+28 127554
+29 116286
+30 107682
+31 96367
+32 88002
+33 79234
+34 72186
+35 65921
+36 60374
+37 55976
+38 52166
+39 47414
+40 44932
+41 41279
+42 38232
+43 35463
+44 33703
+45 30557
+46 29639
+47 27000
+48 25447
+49 23186
+50 22093
+51 20412
+52 19844
+53 18261
+54 17561
+55 16499
+56 15597
+57 14558
+58 14372
+59 13445
+60 13514
+61 12058
+62 11145
+63 10767
+64 10370
+65 9630
+66 9337
+67 8881
+68 8727
+69 8060
+70 7994
+71 7740
+72 7189
+73 6729
+74 6749
+75 6548
+76 6321
+77 5957
+78 5740
+79 5407
+80 5370
+81 5035
+82 4921
+83 4656
+84 4600
+85 4519
+86 4277
+87 4023
+88 3939
+89 3910
+90 3861
+91 3560
+92 3483
+93 3406
+94 3346
+95 3229
+96 3122
+97 3086
+98 3001
+99 2884
+100 2822
+101 2677
+102 2670
+103 2610
+104 2452
+105 2446
+106 2400
+107 2300
+108 2316
+109 2196
+110 2089
+111 2083
+112 2041
+113 1881
+114 1838
+115 1896
+116 1795
+117 1786
+118 1743
+119 1765
+120 1750
+121 1683
+122 1563
+123 1499
+124 1513
+125 1462
+126 1388
+127 1441
+128 1417
+129 1392
+130 1306
+131 1321
+132 1274
+133 1294
+134 1240
+135 1126
+136 1157
+137 1130
+138 1084
+139 1130
+140 1083
+141 1040
+142 980
+143 1031
+144 974
+145 980
+146 932
+147 898
+148 960
+149 907
+150 852
+151 912
+152 859
+153 847
+154 876
+155 792
+156 791
+157 765
+158 788
+159 787
+160 744
+161 673
+162 683
+163 697
+164 666
+165 680
+166 632
+167 677
+168 657
+169 618
+170 587
+171 585
+172 567
+173 549
+174 562
+175 548
+176 542
+177 539
+178 542
+179 549
+180 547
+181 526
+182 525
+183 514
+184 512
+185 505
+186 515
+187 467
+188 475
+189 458
+190 435
+191 443
+192 427
+193 424
+194 404
+195 389
+196 429
+197 404
+198 386
+199 351
+200 388
+201 408
+202 361
+203 346
+204 324
+205 361
+206 363
+207 364
+208 323
+209 336
+210 342
+211 315
+212 325
+213 328
+214 314
+215 327
+216 320
+217 300
+218 295
+219 315
+220 310
+221 295
+222 275
+223 248
+224 274
+225 232
+226 293
+227 259
+228 286
+229 263
+230 242
+231 214
+232 261
+233 231
+234 211
+235 250
+236 233
+237 206
+238 224
+239 210
+240 233
+241 223
+242 216
+243 222
+244 207
+245 212
+246 196
+247 205
+248 201
+249 202
+250 211
+251 201
+252 215
+253 179
+254 163
+255 179
+256 191
+257 188
+258 196
+259 150
+260 154
+261 176
+262 211
+263 166
+264 171
+265 165
+266 149
+267 182
+268 159
+269 161
+270 164
+271 161
+272 141
+273 151
+274 127
+275 129
+276 142
+277 158
+278 148
+279 135
+280 127
+281 134
+282 138
+283 131
+284 126
+285 125
+286 130
+287 126
+288 135
+289 125
+290 135
+291 131
+292 95
+293 135
+294 106
+295 117
+296 136
+297 128
+298 128
+299 118
+300 109
+301 112
+302 117
+303 108
+304 120
+305 100
+306 95
+307 108
+308 112
+309 77
+310 120
+311 104
+312 109
+313 89
+314 98
+315 82
+316 98
+317 93
+318 77
+319 93
+320 77
+321 98
+322 93
+323 86
+324 89
+325 73
+326 70
+327 71
+328 77
+329 87
+330 77
+331 93
+332 100
+333 83
+334 72
+335 74
+336 69
+337 77
+338 68
+339 78
+340 90
+341 98
+342 75
+343 80
+344 63
+345 71
+346 83
+347 66
+348 71
+349 70
+350 62
+351 62
+352 59
+353 63
+354 62
+355 52
+356 64
+357 64
+358 56
+359 49
+360 57
+361 63
+362 60
+363 68
+364 62
+365 55
+366 54
+367 40
+368 75
+369 70
+370 53
+371 58
+372 57
+373 55
+374 69
+375 57
+376 53
+377 43
+378 45
+379 47
+380 56
+381 51
+382 59
+383 51
+384 43
+385 34
+386 57
+387 49
+388 39
+389 46
+390 48
+391 43
+392 40
+393 54
+394 50
+395 41
+396 43
+397 33
+398 27
+399 49
+400 44
+401 44
+402 38
+403 30
+404 32
+405 37
+406 39
+407 42
+408 53
+409 39
+410 34
+411 31
+412 32
+413 52
+414 27
+415 41
+416 34
+417 36
+418 50
+419 35
+420 32
+421 33
+422 45
+423 35
+424 40
+425 29
+426 41
+427 40
+428 39
+429 32
+430 31
+431 34
+432 29
+433 27
+434 26
+435 22
+436 34
+437 28
+438 30
+439 38
+440 35
+441 36
+442 36
+443 27
+444 24
+445 33
+446 31
+447 25
+448 33
+449 27
+450 32
+451 46
+452 31
+453 35
+454 35
+455 34
+456 26
+457 21
+458 25
+459 26
+460 24
+461 27
+462 33
+463 30
+464 35
+465 21
+466 32
+467 19
+468 27
+469 16
+470 28
+471 26
+472 27
+473 26
+474 25
+475 25
+476 27
+477 20
+478 28
+479 22
+480 23
+481 16
+482 25
+483 27
+484 19
+485 23
+486 19
+487 15
+488 15
+489 23
+490 24
+491 19
+492 20
+493 18
+494 17
+495 30
+496 28
+497 20
+498 29
+499 17
+500 19
+501 21
+502 15
+503 24
+504 15
+505 19
+506 25
+507 16
+508 23
+509 26
+510 21
+511 15
+512 12
+513 16
+514 18
+515 24
+516 26
+517 18
+518 8
+519 25
+520 14
+521 8
+522 24
+523 20
+524 18
+525 15
+526 13
+527 17
+528 18
+529 22
+530 21
+531 9
+532 16
+533 17
+534 13
+535 17
+536 15
+537 13
+538 20
+539 13
+540 19
+541 29
+542 10
+543 8
+544 18
+545 13
+546 9
+547 18
+548 10
+549 18
+550 18
+551 9
+552 9
+553 15
+554 13
+555 15
+556 14
+557 14
+558 18
+559 8
+560 13
+561 9
+562 7
+563 12
+564 6
+565 9
+566 9
+567 18
+568 9
+569 10
+570 13
+571 14
+572 13
+573 21
+574 8
+575 16
+576 12
+577 9
+578 16
+579 17
+580 22
+581 6
+582 14
+583 13
+584 15
+585 11
+586 13
+587 5
+588 12
+589 13
+590 15
+591 13
+592 15
+593 12
+594 7
+595 18
+596 12
+597 13
+598 13
+599 13
+600 12
+601 12
+602 10
+603 11
+604 6
+605 6
+606 2
+607 9
+608 8
+609 12
+610 9
+611 12
+612 13
+613 12
+614 14
+615 9
+616 8
+617 9
+618 14
+619 13
+620 12
+621 6
+622 8
+623 8
+624 8
+625 12
+626 8
+627 7
+628 5
+629 8
+630 12
+631 6
+632 10
+633 10
+634 7
+635 8
+636 9
+637 6
+638 9
+639 4
+640 12
+641 4
+642 3
+643 11
+644 10
+645 6
+646 12
+647 12
+648 4
+649 4
+650 9
+651 8
+652 6
+653 5
+654 14
+655 10
+656 11
+657 8
+658 5
+659 5
+660 9
+661 13
+662 4
+663 5
+664 9
+665 11
+666 12
+667 7
+668 13
+669 2
+670 1
+671 7
+672 7
+673 7
+674 10
+675 9
+676 6
+677 5
+678 7
+679 6
+680 3
+681 3
+682 4
+683 9
+684 8
+685 5
+686 3
+687 11
+688 9
+689 2
+690 6
+691 5
+692 9
+693 5
+694 6
+695 5
+696 9
+697 8
+698 3
+699 7
+700 5
+701 9
+702 8
+703 7
+704 2
+705 3
+706 7
+707 6
+708 6
+709 10
+710 2
+711 10
+712 6
+713 7
+714 5
+715 6
+716 4
+717 6
+718 8
+719 4
+720 6
+721 7
+722 5
+723 7
+724 3
+725 10
+726 10
+727 3
+728 7
+729 7
+730 5
+731 2
+732 1
+733 5
+734 1
+735 5
+736 6
+737 2
+738 2
+739 3
+740 7
+741 2
+742 7
+743 4
+744 5
+745 4
+746 5
+747 3
+748 1
+749 4
+750 4
+751 2
+752 4
+753 6
+754 6
+755 6
+756 3
+757 2
+758 5
+759 5
+760 3
+761 4
+762 2
+763 1
+764 8
+765 3
+766 4
+767 3
+768 1
+769 5
+770 3
+771 3
+772 4
+773 4
+774 1
+775 3
+776 2
+777 2
+778 3
+779 3
+780 1
+781 4
+782 3
+783 4
+784 6
+785 3
+786 5
+787 4
+788 2
+789 4
+790 5
+791 4
+792 6
+794 4
+795 1
+796 1
+797 4
+798 2
+799 3
+800 3
+801 1
+802 5
+803 5
+804 3
+805 3
+806 3
+807 4
+808 4
+809 2
+811 5
+812 4
+813 6
+814 3
+815 2
+816 2
+817 3
+818 5
+819 3
+820 1
+821 1
+822 4
+823 3
+824 4
+825 8
+826 3
+827 5
+828 5
+829 3
+830 6
+831 3
+832 4
+833 8
+834 5
+835 3
+836 3
+837 2
+838 4
+839 2
+840 1
+841 3
+842 2
+843 1
+844 3
+846 4
+847 4
+848 3
+849 3
+850 2
+851 3
+853 1
+854 4
+855 4
+856 2
+857 4
+858 1
+859 2
+860 5
+861 1
+862 1
+863 4
+864 2
+865 2
+867 5
+868 1
+869 4
+870 1
+871 1
+872 1
+873 2
+875 5
+876 3
+877 1
+878 3
+879 3
+880 3
+881 2
+882 1
+883 6
+884 2
+885 2
+886 1
+887 1
+888 3
+889 2
+890 2
+891 3
+892 1
+893 3
+894 1
+895 5
+896 1
+897 3
+899 2
+900 2
+902 1
+903 2
+904 4
+905 4
+906 3
+907 1
+908 1
+909 2
+910 5
+911 2
+912 3
+914 1
+915 1
+916 2
+918 2
+919 2
+920 4
+921 4
+922 1
+923 1
+924 4
+925 5
+926 1
+928 2
+929 1
+930 1
+931 1
+932 1
+933 1
+934 2
+935 1
+936 1
+937 1
+938 2
+939 1
+941 1
+942 4
+944 2
+945 2
+946 2
+947 1
+948 1
+950 1
+951 2
+953 1
+954 2
+955 1
+956 1
+957 2
+958 1
+960 3
+962 4
+963 1
+964 1
+965 3
+966 2
+967 2
+968 1
+969 3
+970 3
+972 1
+974 4
+975 3
+976 3
+977 2
+979 2
+980 1
+981 1
+983 5
+984 1
+985 3
+986 1
+987 2
+988 4
+989 2
+991 2
+992 2
+993 1
+994 1
+996 2
+997 2
+998 1
+999 3
+1000 2
+1001 1
+1002 3
+1003 3
+1004 2
+1005 3
+1006 1
+1007 2
+1009 1
+1011 1
+1013 3
+1014 1
+1016 2
+1017 1
+1018 1
+1019 1
+1020 4
+1021 1
+1022 2
+1025 1
+1026 1
+1027 2
+1028 1
+1030 1
+1031 2
+1032 4
+1034 3
+1035 2
+1036 1
+1038 1
+1039 1
+1040 1
+1041 1
+1042 2
+1043 1
+1044 2
+1045 4
+1048 1
+1050 1
+1051 1
+1052 2
+1054 1
+1055 3
+1056 2
+1057 1
+1059 1
+1061 2
+1063 1
+1064 1
+1065 1
+1066 1
+1067 1
+1068 1
+1069 2
+1074 1
+1075 1
+1077 1
+1078 1
+1079 1
+1082 1
+1085 1
+1088 1
+1090 1
+1091 1
+1092 2
+1094 2
+1097 2
+1098 1
+1099 2
+1101 2
+1102 1
+1104 1
+1105 1
+1107 1
+1109 1
+1111 2
+1112 1
+1114 2
+1115 2
+1116 2
+1117 1
+1118 1
+1119 1
+1120 1
+1122 1
+1123 1
+1127 1
+1128 3
+1132 2
+1138 3
+1142 1
+1145 4
+1150 1
+1153 2
+1154 1
+1158 1
+1159 1
+1163 1
+1165 1
+1169 2
+1174 1
+1176 1
+1177 1
+1178 2
+1179 1
+1180 2
+1181 1
+1182 1
+1183 2
+1185 1
+1187 1
+1191 2
+1193 1
+1195 3
+1196 1
+1201 3
+1203 1
+1206 1
+1210 1
+1213 1
+1214 1
+1215 2
+1218 1
+1220 1
+1221 1
+1225 1
+1226 1
+1233 2
+1241 1
+1243 1
+1249 1
+1250 2
+1251 1
+1254 1
+1255 2
+1260 1
+1268 1
+1270 1
+1273 1
+1274 1
+1277 1
+1284 1
+1287 1
+1291 1
+1292 2
+1294 1
+1295 2
+1297 1
+1298 1
+1301 1
+1307 1
+1308 3
+1311 2
+1313 1
+1316 1
+1321 1
+1324 1
+1325 1
+1330 1
+1333 1
+1334 1
+1338 2
+1340 1
+1341 1
+1342 1
+1343 1
+1345 1
+1355 1
+1357 1
+1360 2
+1375 1
+1376 1
+1380 1
+1383 1
+1387 1
+1389 1
+1393 1
+1394 1
+1396 1
+1398 1
+1410 1
+1414 1
+1419 1
+1425 1
+1434 1
+1435 1
+1438 1
+1439 1
+1447 1
+1455 2
+1460 1
+1461 1
+1463 1
+1466 1
+1470 1
+1473 1
+1478 1
+1480 1
+1483 1
+1484 1
+1485 2
+1492 2
+1499 1
+1509 1
+1512 1
+1513 1
+1523 1
+1524 1
+1525 2
+1529 1
+1539 1
+1544 1
+1568 1
+1584 1
+1591 1
+1598 1
+1600 1
+1604 1
+1614 1
+1617 1
+1621 1
+1622 1
+1626 1
+1638 1
+1648 1
+1658 1
+1661 1
+1679 1
+1682 1
+1693 1
+1700 1
+1705 1
+1707 1
+1722 1
+1728 1
+1758 1
+1762 1
+1763 1
+1775 1
+1776 1
+1801 1
+1810 1
+1812 1
+1827 1
+1834 1
+1846 1
+1847 1
+1848 1
+1851 1
+1862 1
+1866 1
+1877 2
+1884 1
+1888 1
+1903 1
+1912 1
+1925 1
+1938 1
+1955 1
+1998 1
+2054 1
+2058 1
+2065 1
+2069 1
+2076 1
+2089 1
+2104 1
+2111 1
+2133 1
+2138 1
+2156 1
+2204 1
+2212 1
+2237 1
+2246 2
+2298 1
+2304 1
+2360 1
+2400 1
+2481 1
+2544 1
+2586 1
+2622 1
+2666 1
+2682 1
+2725 1
+2920 1
+3997 1
+4019 1
+5211 1
+12 19
+14 1
+16 401
+18 2
+20 421
+22 557
+24 625
+26 50
+28 4481
+30 52
+32 550
+34 5840
+36 4644
+38 87
+40 5794
+41 33
+42 571
+44 11805
+46 4711
+47 7
+48 597
+49 12
+50 678
+51 2
+52 14715
+53 3
+54 7322
+55 3
+56 508
+57 39
+58 3486
+59 11
+60 8974
+61 45
+62 1276
+63 4
+64 15693
+65 15
+66 657
+67 13
+68 6409
+69 10
+70 3188
+71 25
+72 1889
+73 27
+74 10370
+75 9
+76 12432
+77 23
+78 520
+79 15
+80 1534
+81 29
+82 2944
+83 23
+84 12071
+85 36
+86 1502
+87 10
+88 10978
+89 11
+90 889
+91 16
+92 4571
+93 17
+94 7855
+95 21
+96 2271
+97 33
+98 1423
+99 15
+100 11096
+101 21
+102 4082
+103 13
+104 5442
+105 25
+106 2113
+107 26
+108 3779
+109 43
+110 1294
+111 29
+112 7860
+113 29
+114 4965
+115 22
+116 7898
+117 25
+118 1772
+119 28
+120 1149
+121 38
+122 1483
+123 32
+124 10572
+125 25
+126 1147
+127 31
+128 1699
+129 22
+130 5533
+131 22
+132 4669
+133 34
+134 3777
+135 10
+136 5412
+137 21
+138 855
+139 26
+140 2485
+141 46
+142 1970
+143 27
+144 6565
+145 40
+146 933
+147 15
+148 7923
+149 16
+150 735
+151 23
+152 1111
+153 33
+154 3714
+155 27
+156 2445
+157 30
+158 3367
+159 10
+160 4646
+161 27
+162 990
+163 23
+164 5679
+165 25
+166 2186
+167 17
+168 899
+169 32
+170 1034
+171 22
+172 6185
+173 32
+174 2685
+175 17
+176 1354
+177 38
+178 1460
+179 15
+180 3478
+181 20
+182 958
+183 20
+184 6055
+185 23
+186 2180
+187 15
+188 1416
+189 30
+190 1284
+191 22
+192 1341
+193 21
+194 2413
+195 18
+196 4984
+197 13
+198 830
+199 22
+200 1834
+201 19
+202 2238
+203 9
+204 3050
+205 22
+206 616
+207 17
+208 2892
+209 22
+210 711
+211 30
+212 2631
+213 19
+214 3341
+215 21
+216 987
+217 26
+218 823
+219 9
+220 3588
+221 20
+222 692
+223 7
+224 2925
+225 31
+226 1075
+227 16
+228 2909
+229 18
+230 673
+231 20
+232 2215
+233 14
+234 1584
+235 21
+236 1292
+237 29
+238 1647
+239 25
+240 1014
+241 30
+242 1648
+243 19
+244 4465
+245 10
+246 787
+247 11
+248 480
+249 25
+250 842
+251 15
+252 1219
+253 23
+254 1508
+255 8
+256 3525
+257 16
+258 490
+259 12
+260 1678
+261 14
+262 822
+263 16
+264 1729
+265 28
+266 604
+267 11
+268 2572
+269 7
+270 1242
+271 15
+272 725
+273 18
+274 1983
+275 13
+276 1662
+277 19
+278 491
+279 12
+280 1586
+281 14
+282 563
+283 10
+284 2363
+285 10
+286 656
+287 14
+288 725
+289 28
+290 871
+291 9
+292 2606
+293 12
+294 961
+295 9
+296 478
+297 13
+298 1252
+299 10
+300 736
+301 19
+302 466
+303 13
+304 2254
+305 12
+306 486
+307 14
+308 1145
+309 13
+310 955
+311 13
+312 1235
+313 13
+314 931
+315 14
+316 1768
+317 11
+318 330
+319 10
+320 539
+321 23
+322 570
+323 12
+324 1789
+325 13
+326 884
+327 5
+328 1422
+329 14
+330 317
+331 11
+332 509
+333 13
+334 1062
+335 12
+336 577
+337 27
+338 378
+339 10
+340 2313
+341 9
+342 391
+343 13
+344 894
+345 17
+346 664
+347 9
+348 453
+349 6
+350 363
+351 15
+352 1115
+353 13
+354 1054
+355 8
+356 1108
+357 12
+358 354
+359 7
+360 363
+361 16
+362 344
+363 11
+364 1734
+365 12
+366 265
+367 10
+368 969
+369 16
+370 316
+371 12
+372 757
+373 7
+374 563
+375 15
+376 857
+377 9
+378 469
+379 9
+380 385
+381 12
+382 921
+383 15
+384 764
+385 14
+386 246
+387 6
+388 1108
+389 14
+390 230
+391 8
+392 266
+393 11
+394 641
+395 8
+396 719
+397 9
+398 243
+399 4
+400 1108
+401 7
+402 229
+403 7
+404 903
+405 7
+406 257
+407 12
+408 244
+409 3
+410 541
+411 6
+412 744
+413 8
+414 419
+415 8
+416 388
+417 19
+418 470
+419 14
+420 612
+421 6
+422 342
+423 3
+424 1179
+425 3
+426 116
+427 14
+428 207
+429 6
+430 255
+431 4
+432 288
+433 12
+434 343
+435 6
+436 1015
+437 3
+438 538
+439 10
+440 194
+441 6
+442 188
+443 15
+444 524
+445 7
+446 214
+447 7
+448 574
+449 6
+450 214
+451 5
+452 635
+453 9
+454 464
+455 5
+456 205
+457 9
+458 163
+459 2
+460 558
+461 4
+462 171
+463 14
+464 444
+465 11
+466 543
+467 5
+468 388
+469 6
+470 141
+471 4
+472 647
+473 3
+474 210
+475 4
+476 193
+477 7
+478 195
+479 7
+480 443
+481 10
+482 198
+483 3
+484 816
+485 6
+486 128
+487 9
+488 215
+489 9
+490 328
+491 7
+492 158
+493 11
+494 335
+495 8
+496 435
+497 6
+498 174
+499 1
+500 373
+501 5
+502 140
+503 7
+504 330
+505 9
+506 149
+507 5
+508 642
+509 3
+510 179
+511 3
+512 159
+513 8
+514 204
+515 7
+516 306
+517 4
+518 110
+519 5
+520 326
+521 6
+522 305
+523 6
+524 294
+525 7
+526 268
+527 5
+528 149
+529 4
+530 133
+531 2
+532 513
+533 10
+534 116
+535 5
+536 258
+537 4
+538 113
+539 4
+540 138
+541 6
+542 116
+544 485
+545 4
+546 93
+547 9
+548 299
+549 3
+550 256
+551 6
+552 92
+553 3
+554 175
+555 6
+556 253
+557 7
+558 95
+559 2
+560 128
+561 4
+562 206
+563 2
+564 465
+565 3
+566 69
+567 3
+568 157
+569 7
+570 97
+571 8
+572 118
+573 5
+574 130
+575 4
+576 301
+577 6
+578 177
+579 2
+580 397
+581 3
+582 80
+583 1
+584 128
+585 5
+586 52
+587 2
+588 72
+589 1
+590 84
+591 6
+592 323
+593 11
+594 77
+595 5
+596 205
+597 1
+598 244
+599 4
+600 69
+601 3
+602 89
+603 5
+604 254
+605 6
+606 147
+607 3
+608 83
+609 3
+610 77
+611 3
+612 194
+613 1
+614 98
+615 3
+616 243
+617 3
+618 50
+619 8
+620 188
+621 4
+622 67
+623 4
+624 123
+625 2
+626 50
+627 1
+628 239
+629 2
+630 51
+631 4
+632 65
+633 5
+634 188
+636 81
+637 3
+638 46
+639 3
+640 103
+641 1
+642 136
+643 3
+644 188
+645 3
+646 58
+648 122
+649 4
+650 47
+651 2
+652 155
+653 4
+654 71
+655 1
+656 71
+657 3
+658 50
+659 2
+660 177
+661 5
+662 66
+663 2
+664 183
+665 3
+666 50
+667 2
+668 53
+669 2
+670 115
+672 66
+673 2
+674 47
+675 1
+676 197
+677 2
+678 46
+679 3
+680 95
+681 3
+682 46
+683 3
+684 107
+685 1
+686 86
+687 2
+688 158
+689 4
+690 51
+691 1
+692 80
+694 56
+695 4
+696 40
+698 43
+699 3
+700 95
+701 2
+702 51
+703 2
+704 133
+705 1
+706 100
+707 2
+708 121
+709 2
+710 15
+711 3
+712 35
+713 2
+714 20
+715 3
+716 37
+717 2
+718 78
+720 55
+721 1
+722 42
+723 2
+724 218
+725 3
+726 23
+727 2
+728 26
+729 1
+730 64
+731 2
+732 65
+734 24
+735 2
+736 53
+737 1
+738 32
+739 1
+740 60
+742 81
+743 1
+744 77
+745 1
+746 47
+747 1
+748 62
+749 1
+750 19
+751 1
+752 86
+753 3
+754 40
+756 55
+757 2
+758 38
+759 1
+760 101
+761 1
+762 22
+764 67
+765 2
+766 35
+767 1
+768 38
+769 1
+770 22
+771 1
+772 82
+773 1
+774 73
+776 29
+777 1
+778 55
+780 23
+781 1
+782 16
+784 84
+785 3
+786 28
+788 59
+789 1
+790 33
+791 3
+792 24
+794 13
+795 1
+796 110
+797 2
+798 15
+800 22
+801 3
+802 29
+803 1
+804 87
+806 21
+808 29
+810 48
+812 28
+813 1
+814 58
+815 1
+816 48
+817 1
+818 31
+819 1
+820 66
+822 17
+823 2
+824 58
+826 10
+827 2
+828 25
+829 1
+830 29
+831 1
+832 63
+833 1
+834 26
+835 3
+836 52
+837 1
+838 18
+840 27
+841 2
+842 12
+843 1
+844 83
+845 1
+846 7
+847 1
+848 10
+850 26
+852 25
+853 1
+854 15
+856 27
+858 32
+859 1
+860 15
+862 43
+864 32
+865 1
+866 6
+868 39
+870 11
+872 25
+873 1
+874 10
+875 1
+876 20
+877 2
+878 19
+879 1
+880 30
+882 11
+884 53
+886 25
+887 1
+888 28
+890 6
+892 36
+894 10
+896 13
+898 14
+900 31
+902 14
+903 2
+904 43
+906 25
+908 9
+910 11
+911 1
+912 16
+913 1
+914 24
+916 27
+918 6
+920 15
+922 27
+923 1
+924 23
+926 13
+928 42
+929 1
+930 3
+932 27
+934 17
+936 8
+937 1
+938 11
+940 33
+942 4
+943 1
+944 18
+946 15
+948 13
+950 18
+952 12
+954 11
+956 21
+958 10
+960 13
+962 5
+964 32
+966 13
+968 8
+970 8
+971 1
+972 23
+973 2
+974 12
+975 1
+976 22
+978 7
+979 1
+980 14
+982 8
+984 22
+985 1
+986 6
+988 17
+989 1
+990 6
+992 13
+994 19
+996 11
+998 4
+1000 9
+1002 2
+1004 14
+1006 5
+1008 3
+1010 9
+1012 29
+1014 6
+1016 22
+1017 1
+1018 8
+1019 1
+1020 7
+1022 6
+1023 1
+1024 10
+1026 2
+1028 8
+1030 11
+1031 2
+1032 8
+1034 9
+1036 13
+1038 12
+1040 12
+1042 3
+1044 12
+1046 3
+1048 11
+1050 2
+1051 1
+1052 2
+1054 11
+1056 6
+1058 8
+1059 1
+1060 23
+1062 6
+1063 1
+1064 8
+1066 3
+1068 6
+1070 8
+1071 1
+1072 5
+1074 3
+1076 5
+1078 3
+1080 11
+1081 1
+1082 7
+1084 18
+1086 4
+1087 1
+1088 3
+1090 3
+1092 7
+1094 3
+1096 12
+1098 6
+1099 1
+1100 2
+1102 6
+1104 14
+1106 3
+1108 6
+1110 5
+1112 2
+1114 8
+1116 3
+1118 3
+1120 7
+1122 10
+1124 6
+1126 8
+1128 1
+1130 4
+1132 3
+1134 2
+1136 5
+1138 5
+1140 8
+1142 3
+1144 7
+1146 3
+1148 11
+1150 1
+1152 5
+1154 1
+1156 5
+1158 1
+1160 5
+1162 3
+1164 6
+1165 1
+1166 1
+1168 4
+1169 1
+1170 3
+1171 1
+1172 2
+1174 5
+1176 3
+1177 1
+1180 8
+1182 2
+1184 4
+1186 2
+1188 3
+1190 2
+1192 5
+1194 6
+1196 1
+1198 2
+1200 2
+1204 10
+1206 2
+1208 9
+1210 1
+1214 6
+1216 3
+1218 4
+1220 9
+1221 2
+1222 1
+1224 5
+1226 4
+1228 8
+1230 1
+1232 1
+1234 3
+1236 5
+1240 3
+1242 1
+1244 3
+1245 1
+1246 4
+1248 6
+1250 2
+1252 7
+1256 3
+1258 2
+1260 2
+1262 3
+1264 4
+1265 1
+1266 1
+1270 1
+1271 1
+1272 2
+1274 3
+1276 3
+1278 1
+1280 3
+1284 1
+1286 1
+1290 1
+1292 3
+1294 1
+1296 7
+1300 2
+1302 4
+1304 3
+1306 2
+1308 2
+1312 1
+1314 1
+1316 3
+1318 2
+1320 1
+1324 8
+1326 1
+1330 1
+1331 1
+1336 2
+1338 1
+1340 3
+1341 1
+1344 1
+1346 2
+1347 1
+1348 3
+1352 1
+1354 2
+1356 1
+1358 1
+1360 3
+1362 1
+1364 4
+1366 1
+1370 1
+1372 3
+1380 2
+1384 2
+1388 2
+1390 2
+1392 2
+1394 1
+1396 1
+1398 1
+1400 2
+1402 1
+1404 1
+1406 1
+1410 1
+1412 5
+1418 1
+1420 1
+1424 1
+1432 2
+1434 2
+1442 3
+1444 5
+1448 1
+1454 1
+1456 1
+1460 3
+1462 4
+1468 1
+1474 1
+1476 1
+1478 2
+1480 1
+1486 2
+1488 1
+1492 1
+1496 1
+1500 3
+1503 1
+1506 1
+1512 2
+1516 1
+1522 1
+1524 2
+1534 4
+1536 1
+1538 1
+1540 2
+1544 2
+1548 1
+1556 1
+1560 1
+1562 1
+1564 2
+1566 1
+1568 1
+1570 1
+1572 1
+1576 1
+1590 1
+1594 1
+1604 1
+1608 1
+1614 1
+1622 1
+1624 2
+1628 1
+1629 1
+1636 1
+1642 1
+1654 2
+1660 1
+1664 1
+1670 1
+1684 4
+1698 1
+1732 3
+1742 1
+1752 1
+1760 1
+1764 1
+1772 2
+1798 1
+1808 1
+1820 1
+1852 1
+1856 1
+1874 1
+1902 1
+1908 1
+1952 1
+2004 1
+2018 1
+2020 1
+2028 1
+2174 1
+2233 1
+2244 1
+2280 1
+2290 1
+2352 1
+2604 1
+4190 1
diff --git a/ppocr/utils/network.py b/ppocr/utils/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..453abb693d4c0ed370c1031b677d5bf51661add9
--- /dev/null
+++ b/ppocr/utils/network.py
@@ -0,0 +1,82 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tarfile
+import requests
+from tqdm import tqdm
+
+from ppocr.utils.logging import get_logger
+
+
+def download_with_progressbar(url, save_path):
+ logger = get_logger()
+ response = requests.get(url, stream=True)
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
+ block_size = 1024 # 1 Kibibyte
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+ with open(save_path, 'wb') as file:
+ for data in response.iter_content(block_size):
+ progress_bar.update(len(data))
+ file.write(data)
+ progress_bar.close()
+ if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
+ logger.error("Something went wrong while downloading models")
+ sys.exit(0)
+
+
+def maybe_download(model_storage_directory, url):
+ # using custom model
+ tar_file_name_list = [
+ 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
+ ]
+ if not os.path.exists(
+ os.path.join(model_storage_directory, 'inference.pdiparams')
+ ) or not os.path.exists(
+ os.path.join(model_storage_directory, 'inference.pdmodel')):
+ assert url.endswith('.tar'), 'Only supports tar compressed package'
+ tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
+ print('download {} to {}'.format(url, tmp_path))
+ os.makedirs(model_storage_directory, exist_ok=True)
+ download_with_progressbar(url, tmp_path)
+ with tarfile.open(tmp_path, 'r') as tarObj:
+ for member in tarObj.getmembers():
+ filename = None
+ for tar_file_name in tar_file_name_list:
+ if tar_file_name in member.name:
+ filename = tar_file_name
+ if filename is None:
+ continue
+ file = tarObj.extractfile(member)
+ with open(
+ os.path.join(model_storage_directory, filename),
+ 'wb') as f:
+ f.write(file.read())
+ os.remove(tmp_path)
+
+
+def is_link(s):
+ return s is not None and s.startswith('http')
+
+
+def confirm_model_dir_url(model_dir, default_model_dir, default_url):
+ url = default_url
+ if model_dir is None or is_link(model_dir):
+ if is_link(model_dir):
+ url = model_dir
+ file_name = url.split('/')[-1][:-4]
+ model_dir = default_model_dir
+ model_dir = os.path.join(model_dir, file_name)
+ return model_dir, url
diff --git a/ppstructure/MANIFEST.in b/ppstructure/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..2961e722b7cebe8e1912be2dd903fcdecb694019
--- /dev/null
+++ b/ppstructure/MANIFEST.in
@@ -0,0 +1,9 @@
+include LICENSE
+include README.md
+
+recursive-include ppocr/utils *.txt utility.py logging.py network.py
+recursive-include ppocr/data/ *.py
+recursive-include ppocr/postprocess *.py
+recursive-include tools/infer *.py
+recursive-include ppstructure *.py
+
diff --git a/ppstructure/README_ch.md b/ppstructure/README_ch.md
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..22505ad83c6dc58adf472f3db94cbf608b9bbd01 100644
--- a/ppstructure/README_ch.md
+++ b/ppstructure/README_ch.md
@@ -0,0 +1,30 @@
+# TableStructurer
+
+1. 代码使用
+```python
+import cv2
+from paddlestructure import PaddleStructure,draw_result
+
+table_engine = PaddleStructure(
+ output='./output/table',
+ show_log=True)
+
+img_path = '../doc/table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+for line in result:
+ print(line)
+
+from PIL import Image
+
+font_path = 'path/to/PaddleOCR/doc/fonts/simfang.ttf'
+image = Image.open(img_path).convert('RGB')
+im_show = draw_result(image, result,font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+2. 命令行使用
+```bash
+paddlestructure --image_dir=../doc/table/1.png
+```
diff --git a/ppstructure/__init__.py b/ppstructure/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7055bee443fb86648b80bcb892778a114bc47d71
--- /dev/null
+++ b/ppstructure/__init__.py
@@ -0,0 +1,17 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .paddlestructure import PaddleStructure, draw_result, to_excel
+
+__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
diff --git a/ppstructure/layout/README.md b/ppstructure/layout/README.md
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/ppstructure/layout/README_ch.md b/ppstructure/layout/README_ch.md
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/ppstructure/paddlestructure.py b/ppstructure/paddlestructure.py
new file mode 100644
index 0000000000000000000000000000000000000000..57a53d6496f66771f1f6f7628751b4f0ac0fc3b5
--- /dev/null
+++ b/ppstructure/paddlestructure.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import sys
+
+__dir__ = os.path.dirname(__file__)
+sys.path.append(__dir__)
+sys.path.append(os.path.join(__dir__, '..'))
+
+import cv2
+import numpy as np
+from pathlib import Path
+
+from ppocr.utils.logging import get_logger
+from ppstructure.predict_system import OCRSystem, save_res
+from ppstructure.table.predict_table import to_excel
+from ppstructure.utility import init_args, draw_result
+
+logger = get_logger()
+from ppocr.utils.utility import check_and_read_gif, get_image_file_list
+from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link
+
+__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
+
+VERSION = '2.1'
+BASE_DIR = os.path.expanduser("~/.paddlestructure/")
+
+model_urls = {
+ 'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar',
+ 'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
+ 'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar'
+
+}
+
+
+def parse_args(mMain=True):
+ import argparse
+ parser = init_args()
+ parser.add_help = mMain
+
+ for action in parser._actions:
+ if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']:
+ action.default = None
+ if mMain:
+ return parser.parse_args()
+ else:
+ inference_args_dict = {}
+ for action in parser._actions:
+ inference_args_dict[action.dest] = action.default
+ return argparse.Namespace(**inference_args_dict)
+
+
+class PaddleStructure(OCRSystem):
+ def __init__(self, **kwargs):
+ params = parse_args(mMain=False)
+ params.__dict__.update(**kwargs)
+ if params.show_log:
+ logger.setLevel(logging.DEBUG)
+ params.use_angle_cls = False
+ # init model dir
+ params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'det'),
+ model_urls['det'])
+ params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'rec'),
+ model_urls['rec'])
+ params.structure_model_dir, structure_url = confirm_model_dir_url(params.structure_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'structure'),
+ model_urls['structure'])
+ # download model
+ maybe_download(params.det_model_dir, det_url)
+ maybe_download(params.rec_model_dir, rec_url)
+ maybe_download(params.structure_model_dir, structure_url)
+
+ if params.rec_char_dict_path is None:
+ params.rec_char_type = 'EN'
+ if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
+ params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
+ else:
+ params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
+ if params.structure_char_dict_path is None:
+ if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
+ params.structure_char_dict_path = str(
+ Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
+ else:
+ params.structure_char_dict_path = str(
+ Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
+
+ print(params)
+ super().__init__(params)
+
+ def __call__(self, img):
+ if isinstance(img, str):
+ # download net image
+ if img.startswith('http'):
+ download_with_progressbar(img, 'tmp.jpg')
+ img = 'tmp.jpg'
+ image_file = img
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ with open(image_file, 'rb') as f:
+ np_arr = np.frombuffer(f.read(), dtype=np.uint8)
+ img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
+ if img is None:
+ logger.error("error in loading image:{}".format(image_file))
+ return None
+ if isinstance(img, np.ndarray) and len(img.shape) == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ res = super().__call__(img)
+ return res
+
+
+def main():
+ # for cmd
+ args = parse_args(mMain=True)
+ image_dir = args.image_dir
+ save_folder = args.output
+ if image_dir.startswith('http'):
+ download_with_progressbar(image_dir, 'tmp.jpg')
+ image_file_list = ['tmp.jpg']
+ else:
+ image_file_list = get_image_file_list(args.image_dir)
+ if len(image_file_list) == 0:
+ logger.error('no images find in {}'.format(args.image_dir))
+ return
+
+ structure_engine = PaddleStructure(**(args.__dict__))
+ for img_path in image_file_list:
+ img_name = os.path.basename(img_path).split('.')[0]
+ logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
+ result = structure_engine(img_path)
+ for item in result:
+ logger.info(item['res'])
+ save_res(result, save_folder, img_name)
+ logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
\ No newline at end of file
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..2cdfcce2eb3ad4abe4407f781eb99e3591ecebde 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import subprocess
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+import cv2
+import numpy as np
+import time
+
+import layoutparser as lp
+
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.utils.logging import get_logger
+from tools.infer.predict_system import TextSystem
+from ppstructure.table.predict_table import TableSystem, to_excel
+from ppstructure.utility import parse_args,draw_result
+
+logger = get_logger()
+
+
+class OCRSystem(object):
+ def __init__(self, args):
+ args.det_limit_type = 'resize_long'
+ args.drop_score = 0
+ self.text_system = TextSystem(args)
+ self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
+ self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
+ threshold=0.5, enable_mkldnn=args.enable_mkldnn,
+ enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
+ self.use_angle_cls = args.use_angle_cls
+ self.drop_score = args.drop_score
+
+ def __call__(self, img):
+ ori_im = img.copy()
+ layout_res = self.table_layout.detect(img[..., ::-1])
+ res_list = []
+ for region in layout_res:
+ x1, y1, x2, y2 = region.coordinates
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+ roi_img = ori_im[y1:y2, x1:x2, :]
+ if region.type == 'Table':
+ res = self.table_system(roi_img)
+ elif region.type == 'Figure':
+ continue
+ else:
+ filter_boxes, filter_rec_res = self.text_system(roi_img)
+ filter_boxes = [x + [x1, y1] for x in filter_boxes]
+ filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
+
+ res = (filter_boxes, filter_rec_res)
+ res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
+ return res_list
+
+def save_res(res, save_folder, img_name):
+ excel_save_folder = os.path.join(save_folder, img_name)
+ os.makedirs(excel_save_folder, exist_ok=True)
+ # save res
+ for region in res:
+ if region['type'] == 'Table':
+ excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
+ to_excel(region['res'], excel_path)
+ elif region['type'] == 'Figure':
+ pass
+ else:
+ with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
+ for box, rec_res in zip(region['res'][0], region['res'][1]):
+ f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ image_file_list = image_file_list
+ image_file_list = image_file_list[args.process_id::args.total_process_num]
+ save_folder = args.output
+ os.makedirs(save_folder, exist_ok=True)
+
+ structure_sys = OCRSystem(args)
+ img_num = len(image_file_list)
+ for i, image_file in enumerate(image_file_list):
+ logger.info("[{}/{}] {}".format(i, img_num, image_file))
+ img, flag = check_and_read_gif(image_file)
+ img_name = os.path.basename(image_file).split('.')[0]
+
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.error("error in loading image:{}".format(image_file))
+ continue
+ starttime = time.time()
+ res = structure_sys(img)
+ save_res(res, save_folder, img_name)
+ draw_img = draw_result(img,res, args.vis_font_path)
+ cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
+ logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
+ elapse = time.time() - starttime
+ logger.info("Predict time : {:.3f}s".format(elapse))
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ if args.use_mp:
+ p_list = []
+ total_process_num = args.total_process_num
+ for process_id in range(total_process_num):
+ cmd = [sys.executable, "-u"] + sys.argv + [
+ "--process_id={}".format(process_id),
+ "--use_mp={}".format(False)
+ ]
+ p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
+ p_list.append(p)
+ for p in p_list:
+ p.wait()
+ else:
+ main(args)
diff --git a/ppstructure/setup.py b/ppstructure/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e68b2e44140f6ad5a13661349666d17cfe45524
--- /dev/null
+++ b/ppstructure/setup.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from setuptools import setup
+from io import open
+import shutil
+
+with open('../requirements.txt', encoding="utf-8-sig") as f:
+ requirements = f.readlines()
+ requirements.append('tqdm')
+ requirements.append('layoutparser')
+ requirements.append('iopath')
+
+
+def readme():
+ with open('README_ch.md', encoding="utf-8-sig") as f:
+ README = f.read()
+ return README
+
+
+shutil.copytree('../ppstructure/table', './ppstructure/table')
+shutil.copyfile('../ppstructure/predict_system.py', './ppstructure/predict_system.py')
+shutil.copyfile('../ppstructure/utility.py', './ppstructure/utility.py')
+shutil.copytree('../ppocr', './ppocr')
+shutil.copytree('../tools', './tools')
+shutil.copyfile('../LICENSE', './LICENSE')
+
+setup(
+ name='paddlestructure',
+ packages=['paddlestructure'],
+ package_dir={'paddlestructure': ''},
+ include_package_data=True,
+ entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
+ version='1.0',
+ install_requires=requirements,
+ license='Apache License 2.0',
+ description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
+ long_description=readme(),
+ long_description_content_type='text/markdown',
+ url='https://github.com/PaddlePaddle/PaddleOCR',
+ download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
+ keywords=[
+ 'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
+ ],
+ classifiers=[
+ 'Intended Audience :: Developers', 'Operating System :: OS Independent',
+ 'Natural Language :: Chinese (Simplified)',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.2',
+ 'Programming Language :: Python :: 3.3',
+ 'Programming Language :: Python :: 3.4',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
+ ], )
+
+shutil.rmtree('ppocr')
+shutil.rmtree('tools')
+shutil.rmtree('ppstructure')
+os.remove('LICENSE')
diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..105231068a99eb6c012a125ba3fb65934c5d4ac6 100644
--- a/ppstructure/table/README_ch.md
+++ b/ppstructure/table/README_ch.md
@@ -0,0 +1,15 @@
+# 表格结构和内容预测
+
+先cd到PaddleOCR/ppstructure目录下
+
+预测
+```bash
+python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --structure_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+```
+运行完成后,每张图片的excel表格会保存到output字段指定的目录下
+
+评估
+
+```bash
+python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --structure_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+```
diff --git a/ppstructure/table/__init__.py b/ppstructure/table/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d11e265597c7c8e39098a228108da3bb954b892
--- /dev/null
+++ b/ppstructure/table/__init__.py
@@ -0,0 +1,13 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py
new file mode 100755
index 0000000000000000000000000000000000000000..1bcbaa8d0d0b2669828dc6b19c3370a30c522ede
--- /dev/null
+++ b/ppstructure/table/eval_table.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import cv2
+import json
+from tqdm import tqdm
+from ppstructure.table.table_metric import TEDS
+from ppstructure.table.predict_table import TableSystem
+from ppstructure.utility import init_args
+
+
+def parse_args():
+    """Parse the command line for table evaluation.
+
+    Takes the parser returned by ``init_args``, adds the string option
+    ``--gt_path`` to it and returns the parsed namespace.
+    """
+    arg_parser = init_args()
+    arg_parser.add_argument("--gt_path", type=str)
+    parsed = arg_parser.parse_args()
+    return parsed
+
+def main(gt_path, img_root, args):
+    """Run the table pipeline on every annotated image and print the mean TEDS.
+
+    Args:
+        gt_path: Path to a JSON file that maps an image name to a 4-item
+            record; items 0 and 3 (structure tokens and cell contents with
+            markup) are used to rebuild the ground-truth HTML.
+        img_root: Directory that contains the images named in the JSON file.
+        args: Parsed options, forwarded to ``TableSystem``.
+    """
+    teds = TEDS(n_jobs=16)
+
+    text_sys = TableSystem(args)
+    # Close the annotation file deterministically instead of leaking the handle.
+    with open(gt_path) as gt_file:
+        jsons_gt = json.load(gt_file)  # gt
+    pred_htmls = []
+    gt_htmls = []
+    for img_name in tqdm(jsons_gt):
+        # read image
+        img = cv2.imread(os.path.join(img_root, img_name))
+        pred_html = text_sys(img)
+        pred_htmls.append(pred_html)
+
+        gt_structures, _gt_bboxes, _gt_contents, contents_with_block = jsons_gt[img_name]
+        gt_html, _ = get_gt_html(gt_structures, contents_with_block)
+        gt_htmls.append(gt_html)
+    if not gt_htmls:
+        # Nothing to score; avoid dividing by zero below.
+        print('teds: no samples found in', gt_path)
+        return
+    # batch_evaluate_html expects the predictions first, the ground truth second.
+    scores = teds.batch_evaluate_html(pred_htmls, gt_htmls)
+    print('teds:', sum(scores) / len(scores))
+
+
+def get_gt_html(gt_structures, contents_with_block):
+    """Interleave structure tokens with cell contents to build the gt HTML.
+
+    Every token that contains the cell marker consumes the next entry of
+    ``contents_with_block``; a non-empty entry is inserted in front of the
+    token. All other tokens are copied unchanged.
+
+    Returns:
+        (html, parts): the joined string and the list of pieces it was
+        joined from.
+    """
+    html_parts = []
+    cell_no = 0
+    for tag in gt_structures:
+        # NOTE(review): the marker literal below looks extraction-damaged
+        # (possibly an HTML cell tag) - confirm against the upstream file.
+        if ' | ' not in tag:
+            html_parts.append(tag)
+            continue
+        cell_content = contents_with_block[cell_no]
+        if cell_content != []:
+            html_parts.extend(cell_content)
+        html_parts.append(tag)
+        cell_no += 1
+    return ''.join(html_parts), html_parts
+
+
+if __name__ == '__main__':
+    # Script entry point: evaluate the images listed in --gt_path.
+    cli_args = parse_args()
+    main(gt_path=cli_args.gt_path, img_root=cli_args.image_dir, args=cli_args)
diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py
new file mode 100755
index 0000000000000000000000000000000000000000..c3b56384403f5fd92a8db4b4bb378a6d55e5a76c
--- /dev/null
+++ b/ppstructure/table/matcher.py
@@ -0,0 +1,192 @@
+import json
+def distance(box_1, box_2):
+    """Return a corner-based L1 distance between two 4-item boxes.
+
+    The result is the sum of the absolute differences of all four items plus
+    the smaller of the two per-corner sums (items 0/1 versus items 2/3).
+    """
+    ax0, ay0, ax1, ay1 = box_1
+    bx0, by0, bx1, by1 = box_2
+    d_x0 = abs(bx0 - ax0)
+    d_y0 = abs(by0 - ay0)
+    d_x1 = abs(bx1 - ax1)
+    d_y1 = abs(by1 - ay1)
+    first_corner = d_x0 + d_y0
+    second_corner = d_x1 + d_y1
+    total = d_x0 + d_y0 + d_x1 + d_y1
+    return total + min(first_corner, second_corner)
+
+def compute_iou(rec1, rec2):
+    """
+    Return the intersection-over-union of two axis-aligned rectangles.
+
+    :param rec1: 4-item sequence; items 0/2 are the low/high bound on one
+        axis and items 1/3 the low/high bound on the other axis
+    :param rec2: rectangle in the same layout as rec1
+    :return: scalar IoU, 0.0 when the rectangles share no positive area
+    """
+    area_1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    area_2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+    total_area = area_1 + area_2
+
+    # Bounds of the candidate intersection rectangle.
+    low_b = max(rec1[1], rec2[1])
+    high_b = min(rec1[3], rec2[3])
+    low_a = max(rec1[0], rec2[0])
+    high_a = min(rec1[2], rec2[2])
+
+    if low_b >= high_b or low_a >= high_a:
+        return 0.0
+    overlap = (high_b - low_b) * (high_a - low_a)
+    return (overlap / (total_area - overlap)) * 1.0
+
+
+
+def matcher_merge(ocr_bboxes, pred_bboxes):
+    """Assign every OCR box to its closest predicted cell box.
+
+    "Closest" means the smallest ``1 - IoU`` and, on ties, the smallest
+    ``distance`` value; if both are equal the earliest cell wins.
+
+    Args:
+        ocr_bboxes: Iterable of 4-item boxes to be assigned.
+        pred_bboxes: Sequence of 4-item candidate boxes.
+
+    Returns:
+        dict mapping an index into ``pred_bboxes`` to the list of indices of
+        the ``ocr_bboxes`` entries assigned to it. Empty when there are no
+        candidate boxes.
+    """
+    matched = {}
+    if len(pred_bboxes) == 0:
+        # No candidates: nothing can be matched (this used to be an IndexError).
+        return matched
+    for i, gt_box in enumerate(ocr_bboxes):
+        # (1 - IoU, distance) for every candidate cell, in selection order
+        costs = [(1. - compute_iou(gt_box, pred_box), distance(gt_box, pred_box))
+                 for pred_box in pred_bboxes]
+        # min() returns the first minimal item, i.e. the same cell the former
+        # sort followed by list.index selected.
+        best = min(range(len(costs)), key=costs.__getitem__)
+        matched.setdefault(best, []).append(i)
+    return matched
+
+def complex_num(pred_bboxes):
+    """Average the IoU between each box and its nearest other box.
+
+    For every box, all boxes that compare unequal to it are considered; the
+    one with the smallest ``distance`` is picked and its IoU with the box is
+    recorded. The mean of the recorded IoUs is returned.
+    """
+    nearest_ious = []
+    for bbox in pred_bboxes:
+        others = [other for other in pred_bboxes if bbox != other]
+        gaps = [distance(bbox, other) for other in others]
+        overlaps = [compute_iou(bbox, other) for other in others]
+        nearest_ious.append(overlaps[gaps.index(min(gaps))])
+    return sum(nearest_ious) / len(nearest_ious)
+
+def get_rows(pred_bboxes):
+    """Split off the leading run of boxes that share a row with the first box.
+
+    Boxes are taken from the front until one has item 1 more than 2 above
+    the first box's item 1, or item 0 smaller than the first box's item 0.
+    The taken boxes are removed from ``pred_bboxes`` in place.
+
+    Returns:
+        (row, pred_bboxes): the taken boxes and the same, now shortened, list.
+    """
+    anchor = pred_bboxes[0]
+    row = []
+    for candidate in pred_bboxes:
+        if candidate[1] - anchor[1] > 2 or candidate[0] - anchor[0] < 0:
+            break
+        row.append(candidate)
+    # Drop the consumed prefix from the caller's list.
+    del pred_bboxes[:len(row)]
+    return row, pred_bboxes
+def refine_rows(pred_bboxes):  # nudge the boxes of one row onto a common horizontal line
+    """Set items 1 and 3 of every box to the row-wide mean of those items.
+
+    The boxes are modified in place; a new list holding the same box objects
+    is returned.
+    """
+    count = len(pred_bboxes)
+    mean_item_1 = sum(box[1] for box in pred_bboxes) / count
+    mean_item_3 = sum(box[3] for box in pred_bboxes) / count
+    for box in pred_bboxes:
+        box[1], box[3] = mean_item_1, mean_item_3
+    return list(pred_bboxes)
+
+def matcher_refine_row(gt_bboxes, pred_bboxes):
+    """Align predicted boxes row by row, then match each gt box to the nearest.
+
+    The predicted boxes are split into rows with ``get_rows`` and each row is
+    passed through ``refine_rows``. Only the outer list is copied, so the box
+    objects inside ``pred_bboxes`` are adjusted in place.
+
+    Args:
+        gt_bboxes: Iterable of 4-item boxes to be assigned.
+        pred_bboxes: List of 4-item mutable boxes.
+
+    Returns:
+        dict mapping the index of a refined predicted box to the list of
+        ``gt_bboxes`` indices whose smallest ``distance`` is to that box.
+        Empty when there are no predicted boxes.
+    """
+    remaining = pred_bboxes.copy()
+    refined = []
+    while remaining:
+        row_bboxes, remaining = get_rows(remaining)
+        refined.extend(refine_rows(row_bboxes))
+    matched = {}
+    if not refined:
+        # No candidates: nothing can be matched (this used to be a ValueError).
+        return matched
+    for i, gt_box in enumerate(gt_bboxes):
+        distances = [distance(gt_box, pred_box) for pred_box in refined]
+        nearest = distances.index(min(distances))
+        matched.setdefault(nearest, []).append(i)
+    return matched
+
+
+
+# Pick out one row first, then do the matching.
+def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
+ """Match gt boxes to predicted boxes, adding one predicted row per gt row.
+
+ Args:
+  gt_bboxes: list of 4-item boxes; a copy of the list is split into rows
+   with get_rows.
+  pred_bboxes_rows: list of rows of predicted boxes; rows are popped from
+   its front, so the caller's list is modified.
+  pred_bboxes: list of predicted boxes; the keys of the result index it.
+
+ Returns:
+  dict mapping an index into pred_bboxes to a list of gt box counters.
+
+ NOTE(review): indentation is lost in this view, so the exact nesting of
+ the statements below cannot be confirmed from here.
+ """
+ gt_box_index = 0
+ delete_gt_bboxes = gt_bboxes.copy()
+ match_bboxes_ready = []
+ matched = {}
+ while(len(delete_gt_bboxes) != 0):
+ # get_rows removes the returned row from delete_gt_bboxes
+ row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
+ # order the row by item 0 of each box
+ row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
+ # make one more predicted row available as match candidates, if any is left
+ if len(pred_bboxes_rows) > 0:
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ # debug output
+ print(row_bboxes)
+ for i, gt_box in enumerate(row_bboxes):
+ #print(gt_box)
+ pred_distances = []
+ distances = []
+ # distances to every predicted box
+ for pred_bbox in pred_bboxes:
+ pred_distances.append(distance(gt_box, pred_bbox))
+ # distances to the candidate boxes only
+ for j, pred_box in enumerate(match_bboxes_ready):
+ distances.append(distance(gt_box, pred_box))
+ # position, among all predicted boxes, of the first box whose distance
+ # equals the smallest candidate distance; min() raises ValueError when
+ # there are no candidates
+ index = pred_distances.index(min(distances))
+ #print('index', index)
+ if index not in matched.keys():
+ matched[index] = [gt_box_index]
+ else:
+ matched[index].append(gt_box_index)
+ # NOTE(review): presumably advanced once per gt box - confirm nesting
+ gt_box_index += 1
+ return matched
+
+def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
+ '''
+ Match gt boxes to predicted boxes while releasing predicted rows step by step.
+
+ gt_bboxes: boxes, already sorted
+ pred_bboxes_rows: list of rows of predicted boxes; rows are popped from its
+  front, so the caller's list is modified
+ pred_bboxes: list of predicted boxes; the keys of the result index it
+
+ Returns a dict mapping an index into pred_bboxes to a list of indices into
+ gt_bboxes.
+
+ NOTE(review): indentation is lost in this view, so the exact nesting of the
+ statements below cannot be confirmed from here.
+ '''
+ # previously handled gt box; starts as the first one
+ pre_bbox = gt_bboxes[0]
+ matched = {}
+ match_bboxes_ready = []
+ # the first predicted row is available as match candidates from the start
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ for i, gt_box in enumerate(gt_bboxes):
+
+ # distances to every predicted box
+ pred_distances = []
+ for pred_bbox in pred_bboxes:
+ pred_distances.append(distance(gt_box, pred_bbox))
+ distances = []
+ # gap_pre is only referenced by the commented-out print below
+ gap_pre = gt_box[1] - pre_bbox[1]
+ # item 0 of this box minus item 2 of the previous box
+ gap_pre_1 = gt_box[0] - pre_bbox[2]
+ #print(gap_pre, len(pred_bboxes_rows))
+ # a negative gap releases the next predicted row
+ # (presumably it signals the start of a new row - TODO confirm)
+ if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0):
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ # when exactly one predicted row is left it is released as well
+ if len(pred_bboxes_rows) == 1:
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ # no candidates and no rows left: stop matching
+ if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
+ break
+ #print(match_bboxes_ready)
+ # distances to the candidate boxes only
+ for j, pred_box in enumerate(match_bboxes_ready):
+ distances.append(distance(gt_box, pred_box))
+ # position, among all predicted boxes, of the first box whose distance
+ # equals the smallest candidate distance
+ index = pred_distances.index(min(distances))
+ #print(gt_box, index)
+ #match_bboxes_ready.pop(distances.index(min(distances)))
+ # debug output
+ print(gt_box, match_bboxes_ready[distances.index(min(distances))])
+ if index not in matched.keys():
+ matched[index] = [i]
+ else:
+ matched[index].append(i)
+ pre_bbox = gt_box
+ return matched
diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py
new file mode 100755
index 0000000000000000000000000000000000000000..6e680b3574ba28b439acad34424b51dfdc02078c
--- /dev/null
+++ b/ppstructure/table/predict_structure.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import math
+import time
+import traceback
+import paddle
+
+import tools.infer.utility as utility
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+
+logger = get_logger()
+
+
+class TableStructurer(object):
+    """Table-structure predictor: preprocessing, inference and label decoding."""
+
+    def __init__(self, args):
+        """Build the pre/post-processing operators and the predictor from ``args``."""
+        resize_op = {'ResizeTableImage': {'max_len': args.structure_max_len}}
+        normalize_op = {
+            'NormalizeImage': {
+                'std': [0.229, 0.224, 0.225],
+                'mean': [0.485, 0.456, 0.406],
+                'scale': '1./255.',
+                'order': 'hwc'
+            }
+        }
+        pre_process_list = [
+            resize_op,
+            normalize_op,
+            {'PaddingTableImage': None},
+            {'ToCHWImage': None},
+            {'KeepKeys': {'keep_keys': ['image']}},
+        ]
+        postprocess_params = {
+            'name': 'TableLabelDecode',
+            "character_type": args.structure_char_type,
+            "character_dict_path": args.structure_char_dict_path,
+            "max_text_length": args.structure_max_text_length,
+            "max_elem_length": args.structure_max_elem_length,
+            "max_cell_num": args.structure_max_cell_num
+        }
+
+        self.preprocess_op = create_operators(pre_process_list)
+        self.postprocess_op = build_post_process(postprocess_params)
+        created = utility.create_predictor(args, 'structure', logger)
+        self.predictor, self.input_tensor, self.output_tensors = created
+
+    def __call__(self, img):
+        """Predict the structure tokens and cell boxes of one table image.
+
+        Returns:
+            ((structure_str_list, res_loc_final), elapse), or ``(None, 0)``
+            when preprocessing yields no image. Boxes are
+            ``[left, top, right, bottom]`` integers scaled to the input image
+            size and clipped to it.
+        """
+        ori_im = img.copy()
+        data = transform({'image': img}, self.preprocess_op)
+        img = data[0]
+        if img is None:
+            return None, 0
+        img = np.expand_dims(img, axis=0).copy()
+        starttime = time.time()
+
+        self.input_tensor.copy_from_cpu(img)
+        self.predictor.run()
+        outputs = [tensor.copy_to_cpu() for tensor in self.output_tensors]
+
+        preds = {'structure_probs': outputs[1], 'loc_preds': outputs[0]}
+        post_result = self.postprocess_op(preds)
+
+        structure_str_list = post_result['structure_str_list']
+        res_loc = post_result['res_loc']
+        imgh, imgw = ori_im.shape[0:2]
+        res_loc_final = []
+        # the decoded locations are multiplied by the image size here
+        for x0, y0, x1, y1 in res_loc[0]:
+            res_loc_final.append([
+                max(int(imgw * x0), 0),
+                max(int(imgh * y0), 0),
+                min(int(imgw * x1), imgw - 1),
+                min(int(imgh * y1), imgh - 1),
+            ])
+
+        # drop the last decoded token, then wrap the sequence
+        # NOTE(review): the wrapper literals below are empty/blank in this
+        # view; they look like markup lost in extraction - confirm upstream.
+        structure_str_list = structure_str_list[0][:-1]
+        structure_str_list = ['', '', ''] + structure_str_list + [' ', '', '']
+
+        elapse = time.time() - starttime
+        return (structure_str_list, res_loc_final), elapse
+
+
+def main(args):
+    """Run structure prediction on every image under ``args.image_dir``.
+
+    The result and the prediction time of each image are logged; images that
+    cannot be loaded are reported and skipped.
+    """
+    image_file_list = get_image_file_list(args.image_dir)
+    table_structurer = TableStructurer(args)
+    processed = 0
+    accumulated_time = 0
+    for image_path in image_file_list:
+        image, is_gif = check_and_read_gif(image_path)
+        image = image if is_gif else cv2.imread(image_path)
+        if image is None:
+            logger.info("error in loading image:{}".format(image_path))
+            continue
+        structure_res, elapse = table_structurer(image)
+
+        logger.info("result: {}".format(structure_res))
+
+        # The first processed image is left out of the accumulated time.
+        accumulated_time += elapse if processed > 0 else 0
+        processed += 1
+        logger.info("Predict time of {}: {}".format(image_path, elapse))
+
+
+if __name__ == "__main__":
+    # Script entry point: parse the shared inference options and run.
+    cli_args = utility.parse_args()
+    main(cli_args)
diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4edd22c3de4df5f0ba3e0a1e28a8c346a48d4ee
--- /dev/null
+++ b/ppstructure/table/predict_table.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import subprocess
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+import cv2
+import copy
+import numpy as np
+import time
+import tools.infer.predict_rec as predict_rec
+import tools.infer.predict_det as predict_det
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.utils.logging import get_logger
+from ppstructure.table.matcher import distance, compute_iou
+from ppstructure.utility import parse_args
+import ppstructure.table.predict_structure as predict_strture
+
+logger = get_logger()
+
+
+def expand(pix, det_box, shape):
+    """Grow a box by ``pix`` on every side, clipped to the image bounds.
+
+    Args:
+        pix: Margin added on each side.
+        det_box: ``(x0, y0, x1, y1)``.
+        shape: ``(h, w, c)`` of the image.
+
+    Returns:
+        ``(x0, y0, x1, y1)`` with the lower bounds floored at 0 and the upper
+        bounds capped at ``w`` and ``h`` respectively.
+    """
+    def floor_at_zero(value):
+        return value if value >= 0 else 0
+
+    def cap_at(value, limit):
+        return value if value <= limit else limit
+
+    x0, y0, x1, y1 = det_box
+    h, w, c = shape
+    return (floor_at_zero(x0 - pix), floor_at_zero(y0 - pix),
+            cap_at(x1 + pix, w), cap_at(y1 + pix, h))
+
+
+class TableSystem(object):
+ """Table recognizer combining structure prediction, text detection and text
+ recognition into one HTML string.
+
+ NOTE(review): indentation is lost in this view, so the nesting of the
+ statements inside the methods cannot be confirmed from here.
+ """
+ def __init__(self, args, text_detector=None, text_recognizer=None):
+ # Reuse the detector / recognizer handed in, otherwise build them from args.
+ self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
+ self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
+ self.table_structurer = predict_strture.TableStructurer(args)
+
+ def __call__(self, img):
+ """Return the predicted HTML for one table image."""
+ ori_im = img.copy()
+ # each model gets its own deep copy of the input image
+ structure_res, elapse = self.table_structurer(copy.deepcopy(img))
+ dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
+ dt_boxes = sorted_boxes(dt_boxes)
+
+ # turn every detected box into [x_min, y_min, x_max, y_max], grown by 1 on each side
+ r_boxes = []
+ for box in dt_boxes:
+ x_min = box[:, 0].min() - 1
+ x_max = box[:, 0].max() + 1
+ y_min = box[:, 1].min() - 1
+ y_max = box[:, 1].max() + 1
+ box = [x_min, y_min, x_max, y_max]
+ r_boxes.append(box)
+ dt_boxes = np.array(r_boxes)
+
+ logger.debug("dt_boxes num : {}, elapse : {}".format(
+ len(dt_boxes), elapse))
+ # NOTE(review): dt_boxes was just assigned the result of np.array(...), so
+ # this check looks unreachable; it also returns a 2-tuple while the normal
+ # path returns a single value - confirm the intended contract.
+ if dt_boxes is None:
+ return None, None
+ img_crop_list = []
+
+ # crop every text box, grown by another 2 and clipped to the image
+ for i in range(len(dt_boxes)):
+ det_box = dt_boxes[i]
+ x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
+ text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
+ img_crop_list.append(text_rect)
+ rec_res, elapse = self.text_recognizer(img_crop_list)
+ logger.debug("rec_res num : {}, elapse : {}".format(
+ len(rec_res), elapse))
+
+ pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
+ return pred_html
+
+ def rebuild_table(self, structure_res, dt_boxes, rec_res):
+ """Match text boxes to predicted cells and build the HTML from the result.
+
+ Returns (pred_html, pred): the joined string and the list of its pieces.
+ """
+ pred_structures, pred_bboxes = structure_res
+ matched_index = self.match_result(dt_boxes, pred_bboxes)
+ pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
+ return pred_html, pred
+
+ def match_result(self, dt_boxes, pred_bboxes):
+ """Map each predicted cell index to the indices of the text boxes nearest to it.
+
+ Nearest is decided by the smallest 1 - IoU, then the smallest distance.
+ """
+ matched = {}
+ for i, gt_box in enumerate(dt_boxes):
+ # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
+ distances = []
+ for j, pred_box in enumerate(pred_bboxes):
+ distances.append(
+ (distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) # L1 distance and 1 - IoU for each pair of cells
+ sorted_distances = distances.copy()
+ # pick the "nearest" cell according to IoU and distance
+ sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
+ if distances.index(sorted_distances[0]) not in matched.keys():
+ matched[distances.index(sorted_distances[0])] = [i]
+ else:
+ matched[distances.index(sorted_distances[0])].append(i)
+ return matched
+
+ def get_pred_html(self, pred_structures, matched_index, ocr_contents):
+ """Insert the recognized text of the matched boxes in front of each cell token.
+
+ Returns (html, parts): the joined string and the list it was joined from.
+
+ NOTE(review): several string literals below are empty or blank in this
+ view although the slices content[3:] and content[:-4] suggest 3- and
+ 4-character markers; they look extraction-damaged - confirm upstream.
+ """
+ end_html = []
+ # running number of the cell token being processed
+ td_index = 0
+ for tag in pred_structures:
+ if ' | ' in tag:
+ if td_index in matched_index.keys():
+ b_with = False
+ if '' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1:
+ b_with = True
+ end_html.extend('')
+ for i, td_index_index in enumerate(matched_index[td_index]):
+ # item 0 of a recognition result is used as the text
+ content = ocr_contents[td_index_index][0]
+ # the clean-up below only applies when several boxes share the cell
+ if len(matched_index[td_index]) > 1:
+ if len(content) == 0:
+ continue
+ # drop one leading blank
+ if content[0] == ' ':
+ content = content[1:]
+ if '' in content:
+ content = content[3:]
+ if '' in content:
+ content = content[:-4]
+ if len(content) == 0:
+ continue
+ # separate the pieces of one cell with a blank, except after the last one
+ if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]:
+ content += ' '
+ # extend() with a string appends it character by character
+ end_html.extend(content)
+ if b_with:
+ end_html.extend('')
+
+ end_html.append(tag)
+ td_index += 1
+ else:
+ end_html.append(tag)
+ return ''.join(end_html), end_html
+
+
+def sorted_boxes(dt_boxes):
+    """
+    Sort text boxes in order from top to bottom, left to right
+    args:
+        dt_boxes(array): detected text boxes; the first point of each box
+            (``box[0]``, read as ``(x, y)``) is the sort key
+    return:
+        list of the boxes in reading order
+    """
+    num_boxes = dt_boxes.shape[0]
+    ordered = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))
+
+    # One pass that swaps neighbours lying on nearly the same line (first
+    # points less than 10 apart vertically) when they are in the wrong
+    # left-to-right order.
+    for idx in range(num_boxes - 1):
+        upper, lower = ordered[idx], ordered[idx + 1]
+        same_line = abs(lower[0][1] - upper[0][1]) < 10
+        if same_line and lower[0][0] < upper[0][0]:
+            ordered[idx], ordered[idx + 1] = lower, upper
+    return ordered
+
+
+def to_excel(html_table, excel_path):
+    """Write ``html_table`` to ``excel_path`` using ``tablepyxl.document_to_xl``."""
+    # imported at call time
+    from tablepyxl.tablepyxl import document_to_xl
+    document_to_xl(html_table, excel_path)
+
+
+def main(args):
+    """Convert every table image under ``args.image_dir`` into an Excel file.
+
+    Only every ``args.total_process_num``-th image starting at
+    ``args.process_id`` is handled, so several processes can share the work.
+    One ``.xlsx`` file per image is written to ``args.table_output``; images
+    that cannot be loaded are logged and skipped.
+    """
+    image_file_list = get_image_file_list(args.image_dir)
+    image_file_list = image_file_list[args.process_id::args.total_process_num]
+    os.makedirs(args.output, exist_ok=True)
+    # The workbooks are written below args.table_output, so that directory
+    # must exist as well; otherwise saving fails on a fresh output path.
+    if args.table_output:
+        os.makedirs(args.table_output, exist_ok=True)
+
+    text_sys = TableSystem(args)
+    img_num = len(image_file_list)
+    for i, image_file in enumerate(image_file_list):
+        logger.info("[{}/{}] {}".format(i, img_num, image_file))
+        img, flag = check_and_read_gif(image_file)
+        # splitext keeps dots inside the stem ("a.b.png" -> "a.b"), so images
+        # whose names differ only after the first dot no longer share a file.
+        stem = os.path.splitext(os.path.basename(image_file))[0]
+        excel_path = os.path.join(args.table_output, stem + '.xlsx')
+        if not flag:
+            img = cv2.imread(image_file)
+        if img is None:
+            logger.error("error in loading image:{}".format(image_file))
+            continue
+        starttime = time.time()
+        pred_html = text_sys(img)
+
+        to_excel(pred_html, excel_path)
+        logger.info('excel saved to {}'.format(excel_path))
+        logger.info(pred_html)
+        # the reported time covers prediction and the Excel export
+        elapse = time.time() - starttime
+        logger.info("Predict time : {:.3f}s".format(elapse))
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    if not args.use_mp:
+        main(args)
+    else:
+        # Re-run this script once per worker, passing each its process id and
+        # switching multiprocessing off in the child; then wait for all.
+        total_process_num = args.total_process_num
+        base_cmd = [sys.executable, "-u"] + sys.argv
+        p_list = [
+            subprocess.Popen(
+                base_cmd + [
+                    "--process_id={}".format(process_id),
+                    "--use_mp={}".format(False)
+                ],
+                stdout=sys.stdout,
+                stderr=sys.stdout)
+            for process_id in range(total_process_num)
+        ]
+        for p in p_list:
+            p.wait()
diff --git a/ppstructure/table/table_metric/__init__.py b/ppstructure/table/table_metric/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..de2d307430f68881ece1e41357d3b2f423e07ddd
--- /dev/null
+++ b/ppstructure/table/table_metric/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ['TEDS']
+from .table_metric import TEDS
\ No newline at end of file
diff --git a/ppstructure/table/table_metric/parallel.py b/ppstructure/table/table_metric/parallel.py
new file mode 100755
index 0000000000000000000000000000000000000000..f7326a1f506ca5fb7b3e97b0d077dc016e7eb7c7
--- /dev/null
+++ b/ppstructure/table/table_metric/parallel.py
@@ -0,0 +1,51 @@
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+
+def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
+    """
+        Apply ``function`` to the elements of ``array`` in a process pool,
+        showing progress bars.
+        Args:
+            array (array-like): The elements to process; must support slicing.
+            function (function): Callable applied to each element.
+            n_jobs (int, default=16): Number of worker processes; 1 runs
+                everything in the current process.
+            use_kwargs (boolean, default=False): Treat each element as a dict
+                of keyword arguments for ``function``.
+            front_num (int, default=0): Number of leading elements processed
+                serially before the pool is started. Useful for catching bugs.
+        Returns:
+            [function(array[0]), function(array[1]), ...]; for pooled calls
+            that raised, the exception object is stored in place of a result.
+    """
+    def call(item):
+        return function(**item) if use_kwargs else function(item)
+
+    # Run the first few elements serially so that bugs surface early.
+    front = [call(item) for item in array[:front_num]] if front_num > 0 else []
+    rest = array[front_num:]
+
+    # A single job needs no pool; handy for benchmarking and debugging.
+    if n_jobs == 1:
+        return front + [call(item) for item in tqdm(rest)]
+
+    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
+        if use_kwargs:
+            futures = [pool.submit(function, **item) for item in rest]
+        else:
+            futures = [pool.submit(function, item) for item in rest]
+        progress_options = {
+            'total': len(futures),
+            'unit': 'it',
+            'unit_scale': True,
+            'leave': True
+        }
+        # Advance the progress bar as tasks complete.
+        for _ in tqdm(as_completed(futures), **progress_options):
+            pass
+
+    # Collect the results in submission order.
+    results = []
+    for _, future in tqdm(enumerate(futures)):
+        try:
+            value = future.result()
+        except Exception as exc:
+            value = exc
+        results.append(value)
+    return front + results
diff --git a/ppstructure/table/table_metric/table_metric.py b/ppstructure/table/table_metric/table_metric.py
new file mode 100755
index 0000000000000000000000000000000000000000..9aca98ad785d4614a803fa5a277a6e4a27b3b078
--- /dev/null
+++ b/ppstructure/table/table_metric/table_metric.py
@@ -0,0 +1,247 @@
+# Copyright 2020 IBM
+# Author: peter.zhong@au1.ibm.com
+#
+# This is free software; you can redistribute it and/or modify
+# it under the terms of the Apache 2.0 License.
+#
+# This software is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# Apache 2.0 License for more details.
+
+import distance
+from apted import APTED, Config
+from apted.helpers import Tree
+from lxml import etree, html
+from collections import deque
+from .parallel import parallel_process
+from tqdm import tqdm
+
+
+class TableTree(Tree):
+    """Tree node for one table element: tag, spans, cell content and children."""
+
+    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
+        self.tag, self.colspan, self.rowspan = tag, colspan, rowspan
+        self.content = content
+        self.children = list(children)
+
+    def bracket(self):
+        """Show tree using brackets notation"""
+        if self.tag != 'td':
+            result = '"tag": %s' % self.tag
+        else:
+            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
+                     (self.tag, self.colspan, self.rowspan, self.content)
+        # append the bracket form of every child, in order
+        result += ''.join(child.bracket() for child in self.children)
+        return "{{{}}}".format(result)
+
+
+class CustomConfig(Config):
+    """Config whose rename cost compares tag, spans and, for cells, the content."""
+
+    @staticmethod
+    def maximum(*sequences):
+        """Return the length of the longest given sequence."""
+        return max(len(seq) for seq in sequences)
+
+    def normalized_distance(self, *sequences):
+        """Return the Levenshtein distance divided by the longest length."""
+        edit_distance = distance.levenshtein(*sequences)
+        return float(edit_distance) / self.maximum(*sequences)
+
+    def rename(self, node1, node2):
+        """Return 1. for differing tag or spans, else the content distance of
+        non-empty 'td' nodes, else 0."""
+        same_shape = (node1.tag == node2.tag
+                      and node1.colspan == node2.colspan
+                      and node1.rowspan == node2.rowspan)
+        if not same_shape:
+            return 1.
+        if node1.tag == 'td' and (node1.content or node2.content):
+            return self.normalized_distance(node1.content, node2.content)
+        return 0.
+
+
+
+class CustomConfig_del_short(Config):
+    """Config variant that replaces cell contents shorter than 3 items with
+    ``['####']`` before comparing them."""
+
+    @staticmethod
+    def maximum(*sequences):
+        """Return the length of the longest given sequence."""
+        return max(len(seq) for seq in sequences)
+
+    def normalized_distance(self, *sequences):
+        """Return the Levenshtein distance divided by the longest length."""
+        edit_distance = distance.levenshtein(*sequences)
+        return float(edit_distance) / self.maximum(*sequences)
+
+    def rename(self, node1, node2):
+        """Compares attributes of trees"""
+        same_shape = (node1.tag == node2.tag
+                      and node1.colspan == node2.colspan
+                      and node1.rowspan == node2.rowspan)
+        if not same_shape:
+            return 1.
+        if node1.tag != 'td' or not (node1.content or node2.content):
+            return 0.
+        compared = []
+        for content in (node1.content, node2.content):
+            # short contents are all treated as the same placeholder
+            compared.append(['####'] if len(content) < 3 else content)
+        return self.normalized_distance(*compared)
+
+class CustomConfig_del_block(Config):
+    """Config variant that ignores ``' '`` items when comparing cell contents."""
+
+    @staticmethod
+    def maximum(*sequences):
+        """Get maximum possible value
+        """
+        return max(map(len, sequences))
+
+    def normalized_distance(self, *sequences):
+        """Get distance from 0 to 1
+        """
+        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
+
+    def rename(self, node1, node2):
+        """Compares attributes of trees
+
+        Returns 1. when tag, colspan or rowspan differ. For 'td' nodes with
+        content, returns the normalized distance of the contents with all
+        ``' '`` items left out. Returns 0. otherwise.
+        """
+        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
+            return 1.
+        if node1.tag == 'td':
+            if node1.content or node2.content:
+                # Filter into new lists so the nodes themselves stay untouched
+                # (the blanks used to be popped from the node content in place,
+                # with a debug print per removed item).
+                node1_content = [tok for tok in node1.content if tok != ' ']
+                node2_content = [tok for tok in node2.content if tok != ' ']
+                if not node1_content and not node2_content:
+                    # Both cells held blanks only: identical after filtering,
+                    # and the normalization below would divide by zero.
+                    return 0.
+                return self.normalized_distance(node1_content, node2_content)
+        return 0.
+
+class TEDS(object):
+ ''' Tree Edit Distance based Similarity
+
+ NOTE(review): indentation is lost in this view, so the nesting of the
+ statements inside the methods cannot be confirmed from here.
+ '''
+
+ def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
+ # structure_only: compare 'td' nodes with empty content (see load_html_tree)
+ # n_jobs: 1 evaluates serially, larger values go through parallel_process
+ # ignore_nodes: tag names stripped from both tables before comparison
+ assert isinstance(n_jobs, int) and (
+ n_jobs >= 1), 'n_jobs must be an integer greather than 1'
+ self.structure_only = structure_only
+ self.n_jobs = n_jobs
+ self.ignore_nodes = ignore_nodes
+ # scratch list filled by tokenize()
+ self.__tokens__ = []
+
+ def tokenize(self, node):
+ ''' Tokenizes table cells
+
+ Appends to self.__tokens__: an opening token for the node, the characters
+ of its text, the tokens of its children (recursively), a closing token
+ unless the tag is 'unk', and the characters of its tail unless the tag
+ is 'td'.
+ '''
+ self.__tokens__.append('<%s>' % node.tag)
+ if node.text is not None:
+ self.__tokens__ += list(node.text)
+ for n in node.getchildren():
+ self.tokenize(n)
+ if node.tag != 'unk':
+ # NOTE(review): the literal below lacks an opening '<'; it looks like a
+ # closing tag damaged in extraction - confirm against the upstream file.
+ self.__tokens__.append('%s>' % node.tag)
+ if node.tag != 'td' and node.tail is not None:
+ self.__tokens__ += list(node.tail)
+
+ def load_html_tree(self, node, parent=None):
+ ''' Converts HTML tree to the format required by apted
+
+ Builds a TableTree node for `node`, attaches it to `parent` when one is
+ given and recurses into the children of non-'td' nodes. The new node is
+ returned only for the root call (parent is None).
+ '''
+ # NOTE(review): this method only uses self.__tokens__; the global
+ # declaration appears to have no effect here.
+ global __tokens__
+ if node.tag == 'td':
+ if self.structure_only:
+ cell = []
+ else:
+ # tokenize the cell from scratch and drop its first and last token
+ self.__tokens__ = []
+ self.tokenize(node)
+ cell = self.__tokens__[1:-1].copy()
+ # missing colspan / rowspan attributes default to 1
+ new_node = TableTree(node.tag,
+ int(node.attrib.get('colspan', '1')),
+ int(node.attrib.get('rowspan', '1')),
+ cell, *deque())
+ else:
+ new_node = TableTree(node.tag, None, None, None, *deque())
+ if parent is not None:
+ parent.children.append(new_node)
+ if node.tag != 'td':
+ for n in node.getchildren():
+ self.load_html_tree(n, new_node)
+ if parent is None:
+ return new_node
+
+ def evaluate(self, pred, true):
+ ''' Computes TEDS score between the prediction and the ground truth of a
+ given sample
+
+ Returns 0.0 when either input is empty or when either document has no
+ body/table element; otherwise 1 - edit_distance / max(node counts).
+ '''
+ if (not pred) or (not true):
+ return 0.0
+ parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
+ pred = html.fromstring(pred, parser=parser)
+ true = html.fromstring(true, parser=parser)
+ if pred.xpath('body/table') and true.xpath('body/table'):
+ # only the first table of each document is compared
+ pred = pred.xpath('body/table')[0]
+ true = true.xpath('body/table')[0]
+ if self.ignore_nodes:
+ etree.strip_tags(pred, *self.ignore_nodes)
+ etree.strip_tags(true, *self.ignore_nodes)
+ # the larger element count of the two tables normalizes the distance
+ n_nodes_pred = len(pred.xpath(".//*"))
+ n_nodes_true = len(true.xpath(".//*"))
+ n_nodes = max(n_nodes_pred, n_nodes_true)
+ tree_pred = self.load_html_tree(pred)
+ tree_true = self.load_html_tree(true)
+ # this local name hides the imported `distance` module inside this method
+ distance = APTED(tree_pred, tree_true,
+ CustomConfig()).compute_edit_distance()
+ return 1.0 - (float(distance) / n_nodes)
+ else:
+ return 0.0
+
+ def batch_evaluate(self, pred_json, true_json):
+ ''' Computes TEDS score between the prediction and the ground truth of
+ a batch of samples
+ @params pred_json: {'FILENAME': 'HTML CODE', ...}
+ @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
+ @output: {'FILENAME': 'TEDS SCORE', ...}
+ '''
+ samples = true_json.keys()
+ # a file name missing from pred_json is evaluated with '' as prediction
+ if self.n_jobs == 1:
+ scores = [self.evaluate(pred_json.get(
+ filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
+ else:
+ inputs = [{'pred': pred_json.get(
+ filename, ''), 'true': true_json[filename]['html']} for filename in samples]
+ scores = parallel_process(
+ inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
+ scores = dict(zip(samples, scores))
+ return scores
+
+ def batch_evaluate_html(self, pred_htmls, true_htmls):
+ ''' Computes TEDS score between the prediction and the ground truth of
+ a batch of samples
+
+ pred_htmls and true_htmls are paired by position with zip; the result
+ is the list of scores in that order.
+ '''
+ if self.n_jobs == 1:
+ scores = [self.evaluate(pred_html, true_html) for (
+ pred_html, true_html) in zip(pred_htmls, true_htmls)]
+ else:
+ inputs = [{"pred": pred_html, "true": true_html} for(
+ pred_html, true_html) in zip(pred_htmls, true_htmls)]
+
+ scores = parallel_process(
+ inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
+ return scores
+
+
+if __name__ == '__main__':
+    import json
+    import pprint
+
+    def _load(path):
+        with open(path) as fp:
+            return json.load(fp)
+
+    # Score the sample predictions against the sample ground truth.
+    pred_json = _load('sample_pred.json')
+    true_json = _load('sample_gt.json')
+    teds = TEDS(n_jobs=4)
+    scores = teds.batch_evaluate(pred_json, true_json)
+    pprint.PrettyPrinter().pprint(scores)
diff --git a/ppstructure/table/tablepyxl/__init__.py b/ppstructure/table/tablepyxl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc0085071cf4497b01fc648e7c38f2e8d9d173d0
--- /dev/null
+++ b/ppstructure/table/tablepyxl/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/ppstructure/table/tablepyxl/style.py b/ppstructure/table/tablepyxl/style.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebd794b1b47d7f9e4f9294dde7330f592d613656
--- /dev/null
+++ b/ppstructure/table/tablepyxl/style.py
@@ -0,0 +1,283 @@
+# This is where we handle translating css styles into openpyxl styles
+# and cascading those from parent to child in the dom.
+
+from openpyxl.cell import cell
+from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color
+from openpyxl.styles.fills import FILL_SOLID
+from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE
+from openpyxl.styles.colors import BLACK
+
+FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy'
+
+
+def colormap(color):
+ """
+ Translate a known color name to its openpyxl constant; any other value is returned unchanged
+ """
+ cmap = {'black': BLACK}  # 'black' is the only name translated
+ return cmap.get(color, color)
+
+
+def style_string_to_dict(style):
+    """
+    Convert a css style string ("a: b; c: d") to a python dictionary; only the first ":" of a declaration separates property from value
+    """
+    def clean_split(string, delim):
+        return (s.strip() for s in string.split(delim, 1))  # maxsplit=1 keeps values such as url(http://...) in one piece
+    styles = [clean_split(s, ":") for s in style.split(";") if ":" in s]  # segments without ":" are dropped
+    return dict(styles)
+
+
+def get_side(style, name):  # keyword arguments for an openpyxl Side, read from the border-<name>-style / border-<name>-color properties
+ return {'border_style': style.get('border-{}-style'.format(name)),
+ 'color': colormap(style.get('border-{}-color'.format(name)))}
+
+known_styles = {}  # module-level cache: stringified (style_dict, parent, number_format) -> NamedStyle
+
+
+def style_dict_to_named_style(style_dict, number_format=None):
+ """
+ Change css style (stored in a StyleDict) to an openpyxl NamedStyle, reusing the cached NamedStyle when the same key was seen before
+ """
+
+ style_and_format_string = str({  # cache key; reads style_dict.parent, so a plain dict is not accepted here
+ 'style_dict': style_dict,
+ 'parent': style_dict.parent,
+ 'number_format': number_format,
+ })
+
+ if style_and_format_string not in known_styles:
+ # Font
+ font = Font(bold=style_dict.get('font-weight') == 'bold',
+ color=style_dict.get_color('color', None),
+ size=style_dict.get('font-size'))
+
+ # Alignment
+ alignment = Alignment(horizontal=style_dict.get('text-align', 'general'),
+ vertical=style_dict.get('vertical-align'),
+ wrap_text=style_dict.get('white-space', 'nowrap') == 'normal')
+
+ # Fill
+ bg_color = style_dict.get_color('background-color')
+ fg_color = style_dict.get_color('foreground-color', Color())
+ fill_type = style_dict.get('fill-type')
+ if bg_color and bg_color != 'transparent':
+ fill = PatternFill(fill_type=fill_type or FILL_SOLID,
+ start_color=bg_color,
+ end_color=fg_color)
+ else:
+ fill = PatternFill()  # no usable background color: default-constructed fill
+
+ # Border
+ border = Border(left=Side(**get_side(style_dict, 'left')),
+ right=Side(**get_side(style_dict, 'right')),
+ top=Side(**get_side(style_dict, 'top')),
+ bottom=Side(**get_side(style_dict, 'bottom')),
+ diagonal=Side(**get_side(style_dict, 'diagonal')),
+ diagonal_direction=None,
+ outline=Side(**get_side(style_dict, 'outline')),
+ vertical=None,
+ horizontal=None)
+
+ name = 'Style {}'.format(len(known_styles) + 1)  # styles are numbered by the current cache size
+
+ pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border,
+ number_format=number_format)
+
+ known_styles[style_and_format_string] = pyxl_style
+
+ return known_styles[style_and_format_string]
+
+
+class StyleDict(dict):
+    """
+    It's like a dictionary, but a key missing locally is looked up in the parent StyleDict, and so on up the chain
+    """
+    def __init__(self, *args, **kwargs):
+        self.parent = kwargs.pop('parent', None)  # consumed here, never stored as a key
+        super(StyleDict, self).__init__(*args, **kwargs)
+
+    def __getitem__(self, item):
+        if item in self:
+            return super(StyleDict, self).__getitem__(item)
+        elif self.parent is not None:  # an empty parent is falsy but must still forward to its own parent
+            return self.parent[item]
+        else:
+            raise KeyError('{} not found'.format(item))
+
+    def __hash__(self):
+        return hash(tuple([(k, self.get(k)) for k in self._keys()]))
+
+    # Yielding the keys avoids creating unnecessary data structures
+    # and happily works with both python2 and python3 where the
+    # .keys() method is a dictionary_view in python3 and a list in python2.
+    def _keys(self):
+        yielded = set()
+        for k in self.keys():
+            yielded.add(k)
+            yield k
+        if self.parent is not None:  # same rule as __getitem__: keep walking through empty parents
+            for k in self.parent._keys():
+                if k not in yielded:
+                    yielded.add(k)
+                    yield k
+
+    def get(self, k, d=None):
+        try:
+            return self[k]
+        except KeyError:
+            return d
+
+    def get_color(self, k, d=None):
+        """
+        Strip leading # off colors if necessary; values without a startswith method (e.g. a Color default) pass through
+        """
+        color = self.get(k, d)
+        if hasattr(color, 'startswith') and color.startswith('#'):
+            color = color[1:]
+            if len(color) == 3:  # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
+                color = ''.join(2 * c for c in color)
+        return color
+
+
+class Element(object):
+ """
+ Our base class for representing an html element along with a cascading style.
+ The element is created along with a parent so that the StyleDict that we store
+ can point to the parent's StyleDict.
+ """
+ def __init__(self, element, parent=None):
+ self.element = element
+ self.number_format = None  # subclasses may overwrite this after the base constructor has run
+ parent_style = parent.style_dict if parent else None
+ self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
+ self._style_cache = None  # filled by the first style() call
+
+ def style(self):
+ """
+ Turn the css styles for this element into an openpyxl NamedStyle (computed once, then cached on the instance).
+ """
+ if not self._style_cache:
+ self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
+ return self._style_cache
+
+ def get_dimension(self, dimension_key):
+ """
+ Extracts the dimension from the style dict of the Element and returns it as a float (a missing/empty value is returned as-is).
+ """
+ dimension = self.style_dict.get(dimension_key)
+ if dimension:
+ if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']:  # drop a two-letter unit suffix; no unit conversion is done
+ dimension = dimension[:-2]
+ dimension = float(dimension)  # NOTE(review): any other suffix (e.g. '%') makes float() raise ValueError
+ return dimension
+
+
+class Table(Element):
+ """
+ The concrete implementations of Elements are semantically named for the types of elements we are interested in.
+ This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
+ allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
+ """
+ def __init__(self, table):
+ """
+ takes an html table object (from lxml); a table has no parent Element, so its style chain starts here
+ """
+ super(Table, self).__init__(table)
+ table_head = table.find('thead')
+ self.head = TableHead(table_head, parent=self) if table_head is not None else None  # head stays None without a thead
+ table_body = table.find('tbody')
+ self.body = TableBody(table_body if table_body is not None else table, parent=self)  # no tbody: the table element itself acts as the body
+
+
+class TableHead(Element):
+ """
+ This class maps to the `thead` element of the html table.
+ """
+ def __init__(self, head, parent=None):
+ super(TableHead, self).__init__(head, parent=parent)
+ self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')]  # direct tr children only
+
+
+class TableBody(Element):
+ """
+ This class maps to the `tbody` element of the html table.
+ """
+ def __init__(self, body, parent=None):
+ super(TableBody, self).__init__(body, parent=parent)
+ self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')]  # direct tr children only
+
+
+class TableRow(Element):
+ """
+ This class maps to the `tr` element of the html table.
+ """
+ def __init__(self, tr, parent=None):
+ super(TableRow, self).__init__(tr, parent=parent)
+ self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')]  # all th cells are listed before all td cells
+
+
+def element_to_string(el):  # text of el and its descendants, newline-separated, with outer whitespace stripped
+ return _element_to_string(el).strip()
+
+
+def _element_to_string(el):  # recursive helper: el's own text, then each child's string, then el's tail
+ string = ''
+
+ for x in el.iterchildren():
+ string += '\n' + _element_to_string(x)
+
+ text = el.text.strip() if el.text else ''
+ tail = el.tail.strip() if el.tail else ''  # tail is the text that follows el's closing tag
+
+ return text + string + '\n' + tail
+
+
+class TableCell(Element):
+ """
+ This class maps to a `th` or `td` element of the html table.
+ """
+ CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE',
+ 'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'}
+
+ def __init__(self, cell, parent=None):
+ super(TableCell, self).__init__(cell, parent=parent)
+ self.value = element_to_string(cell)
+ self.number_format = self.get_number_format()  # replaces the None set by the base constructor
+
+ def data_type(self):
+ cell_types = self.CELL_TYPES & set(self.element.get('class', '').split())  # recognised TYPE_* css classes on this cell
+ if cell_types:
+ if 'TYPE_FORMULA' in cell_types:
+ # Make sure TYPE_FORMULA takes precedence over the other classes in the set.
+ cell_type = 'TYPE_FORMULA'
+ elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}:
+ cell_type = 'TYPE_NUMERIC'
+ else:
+ cell_type = cell_types.pop()  # set.pop: an arbitrary one if several classes remain
+ else:
+ cell_type = 'TYPE_STRING'
+ return getattr(cell, cell_type)  # `cell` here is the openpyxl.cell.cell module imported at the top of the file
+
+ def get_number_format(self):
+ if 'TYPE_CURRENCY' in self.element.get('class', '').split():
+ return FORMAT_CURRENCY_USD_SIMPLE
+ if 'TYPE_INTEGER' in self.element.get('class', '').split():
+ return '#,##0'
+ if 'TYPE_PERCENTAGE' in self.element.get('class', '').split():
+ return FORMAT_PERCENTAGE
+ if 'TYPE_DATE' in self.element.get('class', '').split():  # NOTE(review): TYPE_DATE is not in CELL_TYPES, so data_type() ignores it
+ return FORMAT_DATE_MMDDYYYY
+ if self.data_type() == cell.TYPE_NUMERIC:
+ try:
+ int(self.value)
+ except ValueError:
+ return '#,##0.##'
+ else:
+ return '#,##0'
+
+ def format(self, cell):  # the `cell` argument (a worksheet cell) shadows the module-level `cell` import inside this method
+ cell.style = self.style()
+ data_type = self.data_type()
+ if data_type:
+ cell.data_type = data_type
\ No newline at end of file
diff --git a/ppstructure/table/tablepyxl/tablepyxl.py b/ppstructure/table/tablepyxl/tablepyxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba3cc0fc499fccd93ffe3993a99296bc6603ed8a
--- /dev/null
+++ b/ppstructure/table/tablepyxl/tablepyxl.py
@@ -0,0 +1,118 @@
+# Do imports like python3 so our package works for 2 and 3
+from __future__ import absolute_import
+
+from lxml import html
+from openpyxl import Workbook
+from openpyxl.utils import get_column_letter
+from premailer import Premailer
+from tablepyxl.style import Table
+
+
+def string_to_int(s):  # int(s) when s.isdigit() is true, otherwise 0 (so '', '-1' and ' 2' all give 0)
+ if s.isdigit():
+ return int(s)
+ return 0
+
+
+def get_Tables(doc):  # parse html text and wrap every table element found anywhere in it in a Table
+ tree = html.fromstring(doc)
+ comments = tree.xpath('//comment()')
+ for comment in comments:
+ comment.drop_tag()  # take html comment nodes out of the tree before the tables are wrapped
+ return [Table(table) for table in tree.xpath('//table')]
+
+
+def write_rows(worksheet, elem, row, column=1):
+    """
+    Writes every tr child element of elem to a row in the worksheet, starting at `column`; col/row spans become merged ranges
+    returns the next row after all rows are written
+    """
+    from openpyxl.cell.cell import MergedCell
+
+    initial_column = column
+    for table_row in elem.rows:
+        for table_cell in table_row.cells:
+            cell = worksheet.cell(row=row, column=column)
+            while isinstance(cell, MergedCell):  # skip slots already covered by an earlier row/col span
+                column += 1
+                cell = worksheet.cell(row=row, column=column)
+
+            colspan = max(1, string_to_int(table_cell.element.get("colspan", "1")))  # 0/invalid must not stall `column` (next cell would overwrite this one)
+            rowspan = max(1, string_to_int(table_cell.element.get("rowspan", "1")))  # 0/invalid must not give end_row < start_row in merge_cells
+            if rowspan > 1 or colspan > 1:
+                worksheet.merge_cells(start_row=row, start_column=column,
+                                      end_row=row + rowspan - 1, end_column=column + colspan - 1)
+
+            cell.value = table_cell.value
+            table_cell.format(cell)
+            min_width = table_cell.get_dimension('min-width')
+            max_width = table_cell.get_dimension('max-width')
+
+            if colspan == 1:
+                # Initially, when iterating for the first time through the loop, the width of all the cells is None.
+                # As we start filling in contents, the initial width of the cell (which can be retrieved by:
+                # worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
+                # cell in the same column (i.e. width of A2 = width of A1)
+                width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2)
+                if max_width and width > max_width:
+                    width = max_width
+                elif min_width and width < min_width:
+                    width = min_width
+                worksheet.column_dimensions[get_column_letter(column)].width = width
+            column += colspan
+        row += 1
+        column = initial_column  # every table row restarts at the caller's column
+    return row
+
+
+def table_to_sheet(table, wb):
+ """
+ Takes a table and workbook and writes the table to a new sheet.
+ The sheet title will be the same as the table attribute name.
+ """
+ ws = wb.create_sheet(title=table.element.get('name'))  # without a name attribute, title=None is passed
+ insert_table(table, ws, 1, 1)  # positional order is (column, row); both are 1 here
+
+
+def document_to_workbook(doc, wb=None, base_url=None):
+ """
+ Takes a string representation of an html document and writes one sheet for
+ every table in the document (css is inlined with Premailer first).
+ The workbook is returned
+ """
+ if not wb:
+ wb = Workbook()
+ wb.remove(wb.active)  # drop the sheet a new Workbook starts with, so only table sheets remain
+
+ inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform()  # classes are kept: TableCell reads the TYPE_* classes
+ tables = get_Tables(inline_styles_doc)
+
+ for table in tables:
+ table_to_sheet(table, wb)
+
+ return wb
+
+
+def document_to_xl(doc, filename, base_url=None):
+ """
+ Takes a string representation of an html document and writes one sheet for
+ every table in the document. The workbook is written out to a file called filename
+ """
+ wb = document_to_workbook(doc, base_url=base_url)  # always starts from a fresh workbook
+ wb.save(filename)
+
+
+def insert_table(table, worksheet, column, row):  # write the head rows (if any) and then the body rows, starting at (row, column)
+ if table.head:
+ row = write_rows(worksheet, table.head, row, column)  # body rows continue right below the head rows
+ if table.body:
+ row = write_rows(worksheet, table.body, row, column)  # the final row index is not returned to the caller
+
+
+def insert_table_at_cell(table, cell):
+ """
+ Inserts a table at the location of an openpyxl Cell object.
+ """
+ ws = cell.parent  # used as the worksheet to write into
+ column, row = cell.column, cell.row
+ insert_table(table, ws, column, row)
\ No newline at end of file
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..8112b9efd2155d69784ebc9915d9c3ec30e94f9c
--- /dev/null
+++ b/ppstructure/utility.py
@@ -0,0 +1,59 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from PIL import Image
+import numpy as np
+from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args
+
+
+def init_args():  # start from the parser built by tools.infer.utility.init_args and add the options below
+ parser = infer_args()
+
+ # params for output
+ parser.add_argument("--output", type=str, default='./output/table')
+ # params for table structure
+ parser.add_argument("--structure_max_len", type=int, default=488)
+ parser.add_argument("--structure_max_text_length", type=int, default=100)
+ parser.add_argument("--structure_max_elem_length", type=int, default=800)
+ parser.add_argument("--structure_max_cell_num", type=int, default=500)
+ parser.add_argument("--structure_model_dir", type=str)  # no default is given here
+ parser.add_argument("--structure_char_type", type=str, default='en')
+ parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")  # relative default path
+
+ # params for layout detector
+ parser.add_argument("--layout_model_dir", type=str)  # no default is given here
+ return parser
+
+
+def parse_args():  # build the parser with init_args() and parse the command line
+ parser = init_args()
+ return parser.parse_args()
+
+
+def draw_result(image, result, font_path):  # draw the boxes and texts of every region that is neither 'Table' nor 'Figure'
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ boxes, txts, scores = [], [], []
+ for region in result:
+ if region['type'] == 'Table':
+ pass
+ elif region['type'] == 'Figure':
+ pass
+ else:
+ for box, rec_res in zip(region['res'][0], region['res'][1]):  # res[0]: boxes, res[1]: entries indexed as [0]=text, [1]=score
+ boxes.append(np.array(box).reshape(-1, 2))
+ txts.append(rec_res[0])
+ scores.append(rec_res[1])
+ im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0)
+ return im_show
\ No newline at end of file
diff --git a/tools/infer/benchmark_utils.py b/tools/infer/benchmark_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a241d063368d19567e253bf1dada09801d468bc
--- /dev/null
+++ b/tools/infer/benchmark_utils.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import time
+import logging
+
+import paddle
+import paddle.inference as paddle_infer
+
+from pathlib import Path
+
+CUR_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+class PaddleInferBenchmark(object):
+    def __init__(self,
+                 config,
+                 model_info: dict={},
+                 data_info: dict={},
+                 perf_info: dict={},
+                 resource_info: dict={},
+                 save_log_path: str="",
+                 **kwargs):
+        """
+        Construct PaddleInferBenchmark Class to format logs.
+        args:
+            config(paddle.inference.Config): paddle inference config
+            model_info(dict): basic model info, all keys required
+                {'model_name': 'resnet50', 'precision': 'fp32'}
+            data_info(dict): input data info, all keys required
+                {'batch_size': 1, 'shape': '3,224,224', 'data_num': 1000}
+            perf_info(dict): performance result in seconds, all keys required
+                {'preprocess_time_s': 1.0,
+                 'inference_time_s': 2.0,
+                 'postprocess_time_s': 1.0,
+                 'total_time_s': 4.0}
+            resource_info(dict): cpu and gpu resources, keys optional (default 0);
+                any non-dict value (e.g. None) is treated as "no resource data"
+                {'cpu_rss_mb': 100, 'gpu_rss_mb': 100, 'gpu_util': 60}
+            save_log_path(str): directory for the '<model_name>.log' file;
+                "" (the default) means the current working directory
+            **kwargs: accepted and ignored
+        raises:
+            ValueError: a required key is missing or a value cannot be read/rounded
+        """
+        # PaddleInferBenchmark Log Version
+        self.log_version = 1.0
+
+        # Paddle Version
+        self.paddle_version = paddle.__version__
+        self.paddle_commit = paddle.__git_commit__
+        paddle_infer_info = paddle_infer.get_version()
+        self.paddle_branch = paddle_infer_info.strip().split(': ')[-1]
+
+        # model info
+        self.model_info = model_info
+
+        # data info
+        self.data_info = data_info
+
+        # perf info
+        self.perf_info = perf_info
+
+        try:
+            self.model_name = model_info['model_name']
+            self.precision = model_info['precision']
+
+            self.batch_size = data_info['batch_size']
+            self.shape = data_info['shape']
+            self.data_num = data_info['data_num']
+
+            self.preprocess_time_s = round(perf_info['preprocess_time_s'], 4)
+            self.inference_time_s = round(perf_info['inference_time_s'], 4)
+            self.postprocess_time_s = round(perf_info['postprocess_time_s'], 4)
+            self.total_time_s = round(perf_info['total_time_s'], 4)
+        except Exception as err:  # a bare except would also trap KeyboardInterrupt/SystemExit
+            self.print_help()
+            raise ValueError(
+                "Set argument wrong, please check input argument and its type") from err
+
+        # conf info
+        self.config_status = self.parse_config(config)
+        self.save_log_path = save_log_path
+        # mem info
+        if isinstance(resource_info, dict):
+            self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0))
+            self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0))
+            self.gpu_util = round(resource_info.get('gpu_util', 0), 2)
+        else:
+            self.cpu_rss_mb = 0
+            self.gpu_rss_mb = 0
+            self.gpu_util = 0
+
+        # init benchmark logger
+        self.benchmark_logger()
+
+    def benchmark_logger(self):
+        """
+        Create the log directory, configure logging (file + stream handlers, INFO level) and keep a logger on self.logger
+        """
+        # Init logger
+        FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+        log_output = os.path.join(f"{self.save_log_path}", f"{self.model_name}.log")  # "" now means cwd instead of "/<model>.log"
+        Path(f"{self.save_log_path}").mkdir(parents=True, exist_ok=True)
+        logging.basicConfig(
+            level=logging.INFO,
+            format=FORMAT,
+            handlers=[
+                logging.FileHandler(
+                    filename=log_output, mode='w'),
+                logging.StreamHandler(),
+            ])
+        self.logger = logging.getLogger(__name__)
+        self.logger.info(
+            f"Paddle Inference benchmark log will be saved to {log_output}")
+
+    def parse_config(self, config) -> dict:
+        """
+        parse paddle predictor config
+        args:
+            config(paddle.inference.Config): paddle inference config
+        return:
+            config_status(dict): dict style config info
+        """
+        config_status = {}
+        config_status['runtime_device'] = "gpu" if config.use_gpu() else "cpu"
+        config_status['ir_optim'] = config.ir_optim()
+        config_status['enable_tensorrt'] = config.tensorrt_engine_enabled()
+        config_status['precision'] = self.precision  # taken from model_info, not from the config object
+        config_status['enable_mkldnn'] = config.mkldnn_enabled()
+        config_status[
+            'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads(
+            )
+        return config_status
+
+    def report(self, identifier=None):
+        """
+        print log report
+        args:
+            identifier(string): identify log
+        """
+        if identifier:
+            identifier = f"[{identifier}]"
+        else:
+            identifier = ""
+
+        self.logger.info("\n")
+        self.logger.info(
+            "---------------------- Paddle info ----------------------")
+        self.logger.info(f"{identifier} paddle_version: {self.paddle_version}")
+        self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}")
+        self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}")
+        self.logger.info(f"{identifier} log_api_version: {self.log_version}")
+        self.logger.info(
+            "----------------------- Conf info -----------------------")
+        self.logger.info(
+            f"{identifier} runtime_device: {self.config_status['runtime_device']}"
+        )
+        self.logger.info(
+            f"{identifier} ir_optim: {self.config_status['ir_optim']}")
+        self.logger.info(f"{identifier} enable_memory_optim: {True}")  # hard-coded True, not read from the config
+        self.logger.info(
+            f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}"
+        )
+        self.logger.info(
+            f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}")
+        self.logger.info(
+            f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}"
+        )
+        self.logger.info(
+            "----------------------- Model info ----------------------")
+        self.logger.info(f"{identifier} model_name: {self.model_name}")
+        self.logger.info(f"{identifier} precision: {self.precision}")
+        self.logger.info(
+            "----------------------- Data info -----------------------")
+        self.logger.info(f"{identifier} batch_size: {self.batch_size}")
+        self.logger.info(f"{identifier} input_shape: {self.shape}")
+        self.logger.info(f"{identifier} data_num: {self.data_num}")
+        self.logger.info(
+            "----------------------- Perf info -----------------------")
+        self.logger.info(
+            f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%"
+        )
+        self.logger.info(
+            f"{identifier} total time spent(s): {self.total_time_s}")
+        self.logger.info(
+            f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
+        )
+
+    def print_help(self):
+        """
+        print function help
+        """
+        print("""Usage:
+ ==== Print inference benchmark logs. ====
+ config = paddle.inference.Config()
+ model_info = {'model_name': 'resnet50'
+ 'precision': 'fp32'}
+ data_info = {'batch_size': 1
+ 'shape': '3,224,224'
+ 'data_num': 1000}
+ perf_info = {'preprocess_time_s': 1.0
+ 'inference_time_s': 2.0
+ 'postprocess_time_s': 1.0
+ 'total_time_s': 4.0}
+ resource_info = {'cpu_rss_mb': 100
+ 'gpu_rss_mb': 100
+ 'gpu_util': 60}
+ log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
+ log('Test')
+ """)
+
+    def __call__(self, identifier=None):
+        """
+        __call__: shorthand for report(identifier)
+        args:
+            identifier(string): identify log
+        """
+        self.report(identifier)
diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py
index d2592c6c95b0f466ea3ad5b45a35781282c9a492..0037b226df8e1de8edbdb7668e349925a942e8b9 100755
--- a/tools/infer/predict_cls.py
+++ b/tools/infer/predict_cls.py
@@ -45,9 +45,11 @@ class TextClassifier(object):
"label_list": args.label_list,
}
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors = \
+ self.predictor, self.input_tensor, self.output_tensors, _ = \
utility.create_predictor(args, 'cls', logger)
+ self.cls_times = utility.Timer()
+
def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape
h = img.shape[0]
@@ -83,7 +85,9 @@ class TextClassifier(object):
cls_res = [['', 0.0]] * img_num
batch_num = self.cls_batch_num
elapse = 0
+ self.cls_times.total_time.start()
for beg_img_no in range(0, img_num, batch_num):
+
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
@@ -91,6 +95,7 @@ class TextClassifier(object):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
+ self.cls_times.preprocess_time.start()
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
@@ -98,11 +103,17 @@ class TextClassifier(object):
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
starttime = time.time()
+ self.cls_times.preprocess_time.end()
+ self.cls_times.inference_time.start()
+
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
prob_out = self.output_tensors[0].copy_to_cpu()
+ self.cls_times.inference_time.end()
+ self.cls_times.postprocess_time.start()
self.predictor.try_shrink_memory()
cls_result = self.postprocess_op(prob_out)
+ self.cls_times.postprocess_time.end()
elapse += time.time() - starttime
for rno in range(len(cls_result)):
label, score = cls_result[rno]
@@ -110,6 +121,9 @@ class TextClassifier(object):
if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
+ self.cls_times.total_time.end()
+ self.cls_times.img_num += img_num
+ elapse = self.cls_times.total_time.value()
return img_list, cls_res, elapse
@@ -141,8 +155,9 @@ def main(args):
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
cls_res[ino]))
- logger.info("Total predict time for {} images, cost: {:.3f}".format(
- len(img_list), predict_time))
+ logger.info(
+ "The predict time about text angle classify module is as follows: ")
+ text_classifier.cls_times.info(average=False)
if __name__ == "__main__":
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 59bb49f90abb198933b91f222febad7a416018e8..baa89be130084d98628656fe4e309728a0e9f661 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -31,6 +31,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
+import tools.infer.benchmark_utils as benchmark_utils
+
logger = get_logger()
@@ -41,7 +43,7 @@ class TextDetector(object):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
- 'limit_type': args.det_limit_type
+ 'limit_type': args.det_limit_type,
}
}, {
'NormalizeImage': {
@@ -95,9 +97,10 @@ class TextDetector(object):
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
- args, 'det', logger) # paddle.jit.load(args.det_model_dir)
- # self.predictor.eval()
+ self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
+ args, 'det', logger)
+
+ self.det_times = utility.Timer()
def order_points_clockwise(self, pts):
"""
@@ -155,6 +158,8 @@ class TextDetector(object):
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
+ self.det_times.total_time.start()
+ self.det_times.preprocess_time.start()
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
@@ -162,7 +167,9 @@ class TextDetector(object):
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
- starttime = time.time()
+
+ self.det_times.preprocess_time.end()
+ self.det_times.inference_time.start()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
@@ -170,6 +177,7 @@ class TextDetector(object):
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
+ self.det_times.inference_time.end()
preds = {}
if self.det_algorithm == "EAST":
@@ -184,6 +192,9 @@ class TextDetector(object):
preds['maps'] = outputs[0]
else:
raise NotImplementedError
+
+ self.det_times.postprocess_time.start()
+
self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
@@ -191,8 +202,11 @@ class TextDetector(object):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
- elapse = time.time() - starttime
- return dt_boxes, elapse
+
+ self.det_times.postprocess_time.end()
+ self.det_times.total_time.end()
+ self.det_times.img_num += 1
+ return dt_boxes, self.det_times.total_time.value()
if __name__ == "__main__":
@@ -202,6 +216,13 @@ if __name__ == "__main__":
count = 0
total_time = 0
draw_img_save = "./inference_results"
+ cpu_mem, gpu_mem, gpu_util = 0, 0, 0
+
+ # warmup 10 times
+ fake_img = np.random.uniform(-1, 1, [640, 640, 3]).astype(np.float32)
+ for i in range(10):
+ dt_boxes, _ = text_detector(fake_img)
+
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
for image_file in image_file_list:
@@ -211,16 +232,56 @@ if __name__ == "__main__":
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
- dt_boxes, elapse = text_detector(img)
+ st = time.time()
+ dt_boxes, _ = text_detector(img)
+ elapse = time.time() - st
if count > 0:
total_time += elapse
count += 1
+
+ if args.benchmark:
+ cm, gm, gu = utility.get_current_memory_mb(0)
+ cpu_mem += cm
+ gpu_mem += gm
+ gpu_util += gu
+
logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure))
- cv2.imwrite(img_path, src_im)
+
logger.info("The visualized image saved in {}".format(img_path))
- if count > 1:
- logger.info("Avg Time: {}".format(total_time / (count - 1)))
+ # print the information about memory and time-spent
+ if args.benchmark:
+ mems = {
+ 'cpu_rss_mb': cpu_mem / count,
+ 'gpu_rss_mb': gpu_mem / count,
+ 'gpu_util': gpu_util * 100 / count
+ }
+ else:
+ mems = None
+ logger.info("The predict time about detection module is as follows: ")
+ det_time_dict = text_detector.det_times.report(average=True)
+ det_model_name = args.det_model_dir
+
+ if args.benchmark:
+ # construct log information
+ model_info = {
+ 'model_name': args.det_model_dir.split('/')[-1],
+ 'precision': args.precision
+ }
+ data_info = {
+ 'batch_size': 1,
+ 'shape': 'dynamic_shape',
+ 'data_num': det_time_dict['img_num']
+ }
+ perf_info = {
+ 'preprocess_time_s': det_time_dict['preprocess_time'],
+ 'inference_time_s': det_time_dict['inference_time'],
+ 'postprocess_time_s': det_time_dict['postprocess_time'],
+ 'total_time_s': det_time_dict['total_time']
+ }
+ benchmark_log = benchmark_utils.PaddleInferBenchmark(
+ text_detector.config, model_info, data_info, perf_info, mems)
+ benchmark_log("Det")
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 24388026b8f395427c93e285ed550446e3aa9b9c..2eeb39b2a0bff15241ea7762b4981e4daaada096 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -28,6 +28,7 @@ import traceback
import paddle
import tools.infer.utility as utility
+import tools.infer.benchmark_utils as benchmark_utils
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
@@ -41,7 +42,6 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
- self.max_text_length = args.max_text_length
postprocess_params = {
'name': 'CTCLabelDecode',
"character_type": args.rec_char_type,
@@ -63,9 +63,11 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors = \
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
+ self.rec_times = utility.Timer()
+
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
@@ -166,17 +168,15 @@ class TextRecognizer(object):
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
-
- # rec_res = []
+ self.rec_times.total_time.start()
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
- elapse = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
+ self.rec_times.preprocess_time.start()
for ino in range(beg_img_no, end_img_no):
- # h, w = img_list[ino].shape[0:2]
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
@@ -187,9 +187,8 @@ class TextRecognizer(object):
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
- norm_img = self.process_image_srn(img_list[indices[ino]],
- self.rec_image_shape, 8,
- self.max_text_length)
+ norm_img = self.process_image_srn(
+ img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
@@ -203,7 +202,6 @@ class TextRecognizer(object):
norm_img_batch = norm_img_batch.copy()
if self.rec_algorithm == "SRN":
- starttime = time.time()
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
@@ -218,19 +216,23 @@ class TextRecognizer(object):
gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list,
]
+ self.rec_times.preprocess_time.end()
+ self.rec_times.inference_time.start()
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[
i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
+ self.rec_times.inference_time.end()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {"predict": outputs[2]}
else:
- starttime = time.time()
+ self.rec_times.preprocess_time.end()
+ self.rec_times.inference_time.start()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
@@ -239,22 +241,31 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
- self.predictor.try_shrink_memory()
+ self.rec_times.inference_time.end()
+ self.rec_times.postprocess_time.start()
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
- elapse += time.time() - starttime
- return rec_res, elapse
+ self.rec_times.postprocess_time.end()
+ self.rec_times.img_num += int(norm_img_batch.shape[0])
+ self.rec_times.total_time.end()
+ return rec_res, self.rec_times.total_time.value()
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextRecognizer(args)
- total_run_time = 0.0
- total_images_num = 0
valid_image_file_list = []
img_list = []
- for idx, image_file in enumerate(image_file_list):
+ cpu_mem, gpu_mem, gpu_util = 0, 0, 0
+ count = 0
+
+ # warmup 10 times
+ fake_img = np.random.uniform(-1, 1, [1, 32, 320, 3]).astype(np.float32)
+ for i in range(10):
+ dt_boxes, _ = text_recognizer(fake_img)
+
+ for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
@@ -263,29 +274,54 @@ def main(args):
continue
valid_image_file_list.append(image_file)
img_list.append(img)
- if len(img_list) >= args.rec_batch_num or idx == len(
- image_file_list) - 1:
- try:
- rec_res, predict_time = text_recognizer(img_list)
- total_run_time += predict_time
- except:
- logger.info(traceback.format_exc())
- logger.info(
- "ERROR!!!! \n"
- "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
- "If your model has tps module: "
- "TPS does not support variable shape.\n"
- "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
- )
- exit()
- for ino in range(len(img_list)):
- logger.info("Predicts of {}:{}".format(valid_image_file_list[
- ino], rec_res[ino]))
- total_images_num += len(valid_image_file_list)
- valid_image_file_list = []
- img_list = []
- logger.info("Total predict time for {} images, cost: {:.3f}".format(
- total_images_num, total_run_time))
+ try:
+ rec_res, _ = text_recognizer(img_list)
+ if args.benchmark:
+ cm, gm, gu = utility.get_current_memory_mb(0)
+ cpu_mem += cm
+ gpu_mem += gm
+ gpu_util += gu
+ count += 1
+
+ except Exception as E:
+ logger.info(traceback.format_exc())
+ logger.info(E)
+ exit()
+ for ino in range(len(img_list)):
+ logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
+ rec_res[ino]))
+ if args.benchmark:
+ mems = {
+ 'cpu_rss_mb': cpu_mem / count,
+ 'gpu_rss_mb': gpu_mem / count,
+ 'gpu_util': gpu_util * 100 / count
+ }
+ else:
+ mems = None
+ logger.info("The predict time about recognizer module is as follows: ")
+ rec_time_dict = text_recognizer.rec_times.report(average=True)
+ rec_model_name = args.rec_model_dir
+
+ if args.benchmark:
+ # construct log information
+ model_info = {
+ 'model_name': args.rec_model_dir.split('/')[-1],
+ 'precision': args.precision
+ }
+ data_info = {
+ 'batch_size': args.rec_batch_num,
+ 'shape': 'dynamic_shape',
+ 'data_num': rec_time_dict['img_num']
+ }
+ perf_info = {
+ 'preprocess_time_s': rec_time_dict['preprocess_time'],
+ 'inference_time_s': rec_time_dict['inference_time'],
+ 'postprocess_time_s': rec_time_dict['postprocess_time'],
+ 'total_time_s': rec_time_dict['total_time']
+ }
+ benchmark_log = benchmark_utils.PaddleInferBenchmark(
+ text_recognizer.config, model_info, data_info, perf_info, mems)
+ benchmark_log("Rec")
if __name__ == "__main__":
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 78f5a4729918e33e174705e1b9b0f6e4c27699c6..58a363e3e8fa852ce37cd5a44a19e460da00c2bc 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -13,7 +13,6 @@
# limitations under the License.
import os
import sys
-import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
@@ -32,8 +31,8 @@ import tools.infer.predict_det as predict_det
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
-from tools.infer.utility import draw_ocr_box_txt
-
+from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb
+import tools.infer.benchmark_utils as benchmark_utils
logger = get_logger()
@@ -88,7 +87,8 @@ class TextSystem(object):
def __call__(self, img, cls=True):
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
- logger.info("dt_boxes num : {}, elapse : {}".format(
+
+ logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
return None, None
@@ -103,11 +103,11 @@ class TextSystem(object):
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
- logger.info("cls num : {}, elapse : {}".format(
+ logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
- logger.info("rec_res num : {}, elapse : {}".format(
+ logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], []
@@ -142,23 +142,34 @@ def sorted_boxes(dt_boxes):
def main(args):
image_file_list = get_image_file_list(args.image_dir)
- image_file_list = image_file_list[args.process_id::args.total_process_num]
text_sys = TextSystem(args)
is_visualize = True
font_path = args.vis_font_path
drop_score = args.drop_score
- for image_file in image_file_list:
+ total_time = 0
+ cpu_mem, gpu_mem, gpu_util = 0, 0, 0
+ _st = time.time()
+ count = 0
+ for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
- logger.info("error in loading image:{}".format(image_file))
+ logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime
- logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
+ total_time += elapse
+ if args.benchmark and idx % 20 == 0:
+ cm, gm, gu = get_current_memory_mb(0)
+ cpu_mem += cm
+ gpu_mem += gm
+ gpu_util += gu
+ count += 1
+ logger.info(
+ str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res:
logger.info("{}, {:.3f}".format(text, score))
@@ -178,26 +189,74 @@ def main(args):
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
+ if flag:
+ image_file = image_file[:-3] + "png"
cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1])
logger.info("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
+ logger.info("The predict total time is {}".format(time.time() - _st))
+ logger.info("\nThe predict total time is {}".format(total_time))
-if __name__ == "__main__":
- args = utility.parse_args()
- if args.use_mp:
- p_list = []
- total_process_num = args.total_process_num
- for process_id in range(total_process_num):
- cmd = [sys.executable, "-u"] + sys.argv + [
- "--process_id={}".format(process_id),
- "--use_mp={}".format(False)
- ]
- p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
- p_list.append(p)
- for p in p_list:
- p.wait()
+ img_num = text_sys.text_detector.det_times.img_num
+ if args.benchmark:
+ mems = {
+ 'cpu_rss_mb': cpu_mem / count,
+ 'gpu_rss_mb': gpu_mem / count,
+ 'gpu_util': gpu_util * 100 / count
+ }
else:
- main(args)
+ mems = None
+ det_time_dict = text_sys.text_detector.det_times.report(average=True)
+ rec_time_dict = text_sys.text_recognizer.rec_times.report(average=True)
+ det_model_name = args.det_model_dir
+ rec_model_name = args.rec_model_dir
+
+ # construct det log information
+ model_info = {
+ 'model_name': args.det_model_dir.split('/')[-1],
+ 'precision': args.precision
+ }
+ data_info = {
+ 'batch_size': 1,
+ 'shape': 'dynamic_shape',
+ 'data_num': det_time_dict['img_num']
+ }
+ perf_info = {
+ 'preprocess_time_s': det_time_dict['preprocess_time'],
+ 'inference_time_s': det_time_dict['inference_time'],
+ 'postprocess_time_s': det_time_dict['postprocess_time'],
+ 'total_time_s': det_time_dict['total_time']
+ }
+
+ benchmark_log = benchmark_utils.PaddleInferBenchmark(
+ text_sys.text_detector.config, model_info, data_info, perf_info, mems,
+ args.save_log_path)
+ benchmark_log("Det")
+
+ # construct rec log information
+ model_info = {
+ 'model_name': args.rec_model_dir.split('/')[-1],
+ 'precision': args.precision
+ }
+ data_info = {
+ 'batch_size': args.rec_batch_num,
+ 'shape': 'dynamic_shape',
+ 'data_num': rec_time_dict['img_num']
+ }
+ perf_info = {
+ 'preprocess_time_s': rec_time_dict['preprocess_time'],
+ 'inference_time_s': rec_time_dict['inference_time'],
+ 'postprocess_time_s': rec_time_dict['postprocess_time'],
+ 'total_time_s': rec_time_dict['total_time']
+ }
+ benchmark_log = benchmark_utils.PaddleInferBenchmark(
+ text_sys.text_recognizer.config, model_info, data_info, perf_info, mems,
+ args.save_log_path)
+ benchmark_log("Rec")
+
+
+if __name__ == "__main__":
+ main(utility.parse_args())
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 3f0ff2ff64bff2c2e70be37a95b5449deaa90046..9210c45783029996f8d6fd105c8413d01c768806 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -37,7 +37,7 @@ def init_args():
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
- parser.add_argument("--use_fp16", type=str2bool, default=False)
+ parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
# params for text detector
@@ -109,6 +109,11 @@ def init_args():
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
+
+    # str2bool rather than bool: argparse applies `type` to the raw string, and
+    # bool("False") is True, so type=bool would enable benchmarking for any value.
+    parser.add_argument("--benchmark", type=str2bool, default=False)
+ parser.add_argument("--save_log_path", type=str, default="./log_output/")
+
+ parser.add_argument("--show_log", type=str2bool, default=True)
return parser
@@ -118,6 +123,76 @@ def parse_args():
return parser.parse_args()
+class Times(object):
+    """Stopwatch holding one duration in seconds.
+
+    Attributes are `time` (the recorded duration), `st` (last start timestamp)
+    and `et` (last end timestamp), all taken from time.time().
+    """
+
+    def __init__(self):
+        self.time = self.st = self.et = 0.
+
+    def start(self):
+        """Remember the current time as the start of an interval."""
+        self.st = time.time()
+
+    def end(self, accumulative=True):
+        """Close the interval opened by start().
+
+        Args:
+            accumulative: when True the elapsed seconds are added to the
+                recorded duration; when False they replace it.
+        """
+        self.et = time.time()
+        elapsed = self.et - self.st
+        self.time = self.time + elapsed if accumulative else elapsed
+
+    def reset(self):
+        """Zero the recorded duration and both timestamps."""
+        self.time = 0.
+        self.st = 0.
+        self.et = 0.
+
+    def value(self):
+        """Return the recorded duration rounded to 4 decimal places."""
+        return round(self.time, 4)
+
+
+class Timer(Times):
+    """Group of stage stopwatches (total, preprocess, inference, postprocess)
+    together with the number of images they cover."""
+
+    def __init__(self):
+        super(Timer, self).__init__()
+        self.total_time = Times()
+        self.preprocess_time = Times()
+        self.inference_time = Times()
+        self.postprocess_time = Times()
+        self.img_num = 0
+
+    def _stage_time(self, stage, average):
+        """Return a stage's seconds, divided by img_num when `average` is set.
+
+        With no images counted there is nothing to average over, so 0.0 is
+        returned instead of raising ZeroDivisionError.
+        """
+        seconds = stage.value()
+        if not average:
+            return seconds
+        if not self.img_num:
+            return 0.
+        return round(seconds / self.img_num, 4)
+
+    def info(self, average=False):
+        """Log total time, image count, average latency, QPS and stage latencies.
+
+        Args:
+            average: when True the per-stage latencies are per-image averages,
+                otherwise they are the accumulated totals.
+        """
+        total = self.total_time.value()
+        logger.info("----------------------- Perf info -----------------------")
+        logger.info("total_time: {}, img_num: {}".format(total, self.img_num))
+        preprocess_time = self._stage_time(self.preprocess_time, average)
+        postprocess_time = self._stage_time(self.postprocess_time, average)
+        inference_time = self._stage_time(self.inference_time, average)
+
+        # Both divisions are guarded: img_num is 0 until an image has been
+        # counted, and a total that rounds to 0 would otherwise break the QPS.
+        average_latency = total / self.img_num if self.img_num else 0.
+        qps = 1 / average_latency if average_latency else 0.
+        # "{:.2f}" for QPS: "{:2f}" is a field width of 2, not two decimals.
+        logger.info("average_latency(ms): {:.2f}, QPS: {:.2f}".format(
+            average_latency * 1000, qps))
+        logger.info(
+            "preprocess_latency(ms): {:.2f}, inference_latency(ms): {:.2f}, postprocess_latency(ms): {:.2f}".
+            format(preprocess_time * 1000, inference_time * 1000,
+                   postprocess_time * 1000))
+
+    def report(self, average=False):
+        """Return the timings as a dict.
+
+        Args:
+            average: when True the three stage times are per-image averages
+                (0.0 if no image was counted); total_time is never averaged.
+
+        Returns:
+            dict with keys 'preprocess_time', 'postprocess_time',
+            'inference_time', 'img_num' and 'total_time'.
+        """
+        dic = {}
+        dic['preprocess_time'] = self._stage_time(self.preprocess_time,
+                                                  average)
+        dic['postprocess_time'] = self._stage_time(self.postprocess_time,
+                                                   average)
+        dic['inference_time'] = self._stage_time(self.inference_time, average)
+        dic['img_num'] = self.img_num
+        dic['total_time'] = round(self.total_time.value(), 4)
+        return dic
+
+
def create_predictor(args, mode, logger):
if mode == "det":
model_dir = args.det_model_dir
@@ -125,6 +200,8 @@ def create_predictor(args, mode, logger):
model_dir = args.cls_model_dir
elif mode == 'rec':
model_dir = args.rec_model_dir
+ elif mode == 'structure':
+ model_dir = args.structure_model_dir
else:
model_dir = args.e2e_model_dir
@@ -142,6 +219,16 @@ def create_predictor(args, mode, logger):
config = inference.Config(model_file_path, params_file_path)
+ if hasattr(args, 'precision'):
+ if args.precision == "fp16" and args.use_tensorrt:
+ precision = inference.PrecisionType.Half
+ elif args.precision == "int8":
+ precision = inference.PrecisionType.Int8
+ else:
+ precision = inference.PrecisionType.Float32
+ else:
+ precision = inference.PrecisionType.Float32
+
if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0)
if args.use_tensorrt:
@@ -244,7 +331,9 @@ def create_predictor(args, mode, logger):
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
-
+ config.switch_ir_optim(True)
+ if mode == 'structure':
+ config.switch_ir_optim(False)
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
@@ -255,7 +344,7 @@ def create_predictor(args, mode, logger):
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
- return predictor, input_tensor, output_tensors
+ return predictor, input_tensor, output_tensors, config
def draw_e2e_res(dt_boxes, strs, img_path):
@@ -506,5 +595,30 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
return image
+def get_current_memory_mb(gpu_id=None):
+    """
+    Sample the memory usage of the current process and, optionally, of one GPU.
+
+    This call is time-consuming, so avoid invoking it on every iteration.
+
+    Args:
+        gpu_id: index of the GPU to query; None skips all GPU queries.
+
+    Returns:
+        tuple: (cpu_mem, gpu_mem, gpu_percent), each rounded to 4 decimals.
+            cpu_mem is the USS of this process in MB, gpu_mem the used memory
+            of the device in MB and gpu_percent the load reported by GPUtil.
+            Both GPU values are 0 when gpu_id is None.
+    """
+    import psutil
+    pid = os.getpid()
+    p = psutil.Process(pid)
+    info = p.memory_full_info()
+    cpu_mem = info.uss / 1024. / 1024.
+    gpu_mem = 0
+    gpu_percent = 0
+    if gpu_id is not None:
+        # Imported only on the GPU path so that CPU-only environments do not
+        # need the GPU packages installed.
+        import pynvml
+        import GPUtil
+        GPUs = GPUtil.getGPUs()
+        gpu_percent = GPUs[gpu_id].load
+        pynvml.nvmlInit()
+        try:
+            # Query the requested device so memory and load describe the same
+            # GPU; the handle was previously hard-coded to index 0.
+            handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
+            meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
+            gpu_mem = meminfo.used / 1024. / 1024.
+        finally:
+            # Balance nvmlInit() so repeated sampling does not leave NVML open.
+            pynvml.nvmlShutdown()
+    return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
+
+
if __name__ == '__main__':
pass
|