diff --git a/ppocr/utils/network.py b/ppocr/utils/network.py index 7d98f5c310866d4bd99b0ea80fac2fd7b9620a36..453abb693d4c0ed370c1031b677d5bf51661add9 100644 --- a/ppocr/utils/network.py +++ b/ppocr/utils/network.py @@ -20,6 +20,7 @@ from tqdm import tqdm from ppocr.utils.logging import get_logger + def download_with_progressbar(url, save_path): logger = get_logger() response = requests.get(url, stream=True) @@ -45,6 +46,7 @@ def maybe_download(model_storage_directory, url): os.path.join(model_storage_directory, 'inference.pdiparams') ) or not os.path.exists( os.path.join(model_storage_directory, 'inference.pdmodel')): + assert url.endswith('.tar'), 'Only supports tar compressed package' tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) print('download {} to {}'.format(url, tmp_path)) os.makedirs(model_storage_directory, exist_ok=True) @@ -64,3 +66,17 @@ def maybe_download(model_storage_directory, url): f.write(file.read()) os.remove(tmp_path) + +def is_link(s): + return s is not None and s.startswith('http') + + +def confirm_model_dir_url(model_dir, default_model_dir, default_url): + url = default_url + if model_dir is None or is_link(model_dir): + if is_link(model_dir): + url = model_dir + file_name = url.split('/')[-1][:-4] + model_dir = default_model_dir + model_dir = os.path.join(model_dir, file_name) + return model_dir, url diff --git a/ppstructure/paddlestructure.py b/ppstructure/paddlestructure.py index 03f905650d8d165ebbeec924e79a280cc592939d..f9bfc97990b7f49ee6a857faf313dfd73ec005ab 100644 --- a/ppstructure/paddlestructure.py +++ b/ppstructure/paddlestructure.py @@ -30,7 +30,7 @@ from ppstructure.utility import init_args, draw_result logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list -from ppocr.utils.network import maybe_download, download_with_progressbar +from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link __all__ = ['PaddleStructure', 'draw_result', 'to_excel'] @@ -70,16 +70,19 @@ class PaddleStructure(OCRSystem): logger.setLevel(logging.DEBUG) params.use_angle_cls = False # init model dir - if params.det_model_dir is None: - params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det') - if params.rec_model_dir is None: - params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec') - if params.structure_model_dir is None: - params.structure_model_dir = os.path.join(BASE_DIR, VERSION, 'structure') + params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir, + os.path.join(BASE_DIR, VERSION, 'det'), + model_urls['det']) + params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir, + os.path.join(BASE_DIR, VERSION, 'rec'), + model_urls['rec']) + params.structure_model_dir, structure_url = confirm_model_dir_url(params.structure_model_dir, + os.path.join(BASE_DIR, VERSION, 'structure'), + model_urls['structure']) # download model - maybe_download(params.det_model_dir, model_urls['det']) - maybe_download(params.rec_model_dir, model_urls['rec']) - maybe_download(params.structure_model_dir, model_urls['structure']) + maybe_download(params.det_model_dir, det_url) + maybe_download(params.rec_model_dir, rec_url) + maybe_download(params.structure_model_dir, structure_url) if params.rec_char_dict_path is None: params.rec_char_type = 'EN' @@ -143,3 +146,24 @@ def main(): logger.info(item['res']) save_res(result, save_folder, img_name) logger.info('result save to {}'.format(os.path.join(save_folder, img_name))) + + +if __name__ == '__main__': + table_engine = PaddleStructure( + output='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table', + show_log=True) + + img_path = '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/ppstructure/test_imgs/paper-image.jpg' + img = cv2.imread(img_path) + result = table_engine(img) + for line in result: + print(line) + + from PIL import Image + + font_path = '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR//doc/fonts/simfang.ttf' + image = Image.open(img_path).convert('RGB') + im_show = draw_result(image, result, + font_path='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR//doc/fonts/simfang.ttf') + im_show = Image.fromarray(im_show) + im_show.save('result.jpg')