提交 2f074068 编写于 作者: W WenmuZhou

whl package support send model url

上级 ac56eba7
...@@ -20,6 +20,7 @@ from tqdm import tqdm ...@@ -20,6 +20,7 @@ from tqdm import tqdm
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
def download_with_progressbar(url, save_path): def download_with_progressbar(url, save_path):
logger = get_logger() logger = get_logger()
response = requests.get(url, stream=True) response = requests.get(url, stream=True)
...@@ -45,6 +46,7 @@ def maybe_download(model_storage_directory, url): ...@@ -45,6 +46,7 @@ def maybe_download(model_storage_directory, url):
os.path.join(model_storage_directory, 'inference.pdiparams') os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists( ) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')): os.path.join(model_storage_directory, 'inference.pdmodel')):
assert url.endswith('.tar'), 'Only supports tar compressed package'
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path)) print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True) os.makedirs(model_storage_directory, exist_ok=True)
...@@ -64,3 +66,17 @@ def maybe_download(model_storage_directory, url): ...@@ -64,3 +66,17 @@ def maybe_download(model_storage_directory, url):
f.write(file.read()) f.write(file.read())
os.remove(tmp_path) os.remove(tmp_path)
def is_link(s):
return s is not None and s.startswith('http')
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
url = default_url
if model_dir is None or is_link(model_dir):
if is_link(model_dir):
url = model_dir
file_name = url.split('/')[-1][:-4]
model_dir = default_model_dir
model_dir = os.path.join(model_dir, file_name)
return model_dir, url
...@@ -30,7 +30,7 @@ from ppstructure.utility import init_args, draw_result ...@@ -30,7 +30,7 @@ from ppstructure.utility import init_args, draw_result
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link
__all__ = ['PaddleStructure', 'draw_result', 'to_excel'] __all__ = ['PaddleStructure', 'draw_result', 'to_excel']
...@@ -70,16 +70,19 @@ class PaddleStructure(OCRSystem): ...@@ -70,16 +70,19 @@ class PaddleStructure(OCRSystem):
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
params.use_angle_cls = False params.use_angle_cls = False
# init model dir # init model dir
if params.det_model_dir is None: params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det') os.path.join(BASE_DIR, VERSION, 'det'),
if params.rec_model_dir is None: model_urls['det'])
params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec') params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
if params.structure_model_dir is None: os.path.join(BASE_DIR, VERSION, 'rec'),
params.structure_model_dir = os.path.join(BASE_DIR, VERSION, 'structure') model_urls['rec'])
params.structure_model_dir, structure_url = confirm_model_dir_url(params.structure_model_dir,
os.path.join(BASE_DIR, VERSION, 'structure'),
model_urls['structure'])
# download model # download model
maybe_download(params.det_model_dir, model_urls['det']) maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, model_urls['rec']) maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.structure_model_dir, model_urls['structure']) maybe_download(params.structure_model_dir, structure_url)
if params.rec_char_dict_path is None: if params.rec_char_dict_path is None:
params.rec_char_type = 'EN' params.rec_char_type = 'EN'
...@@ -143,3 +146,24 @@ def main(): ...@@ -143,3 +146,24 @@ def main():
logger.info(item['res']) logger.info(item['res'])
save_res(result, save_folder, img_name) save_res(result, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name))) logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
if __name__ == '__main__':
table_engine = PaddleStructure(
output='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
show_log=True)
img_path = '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/ppstructure/test_imgs/paper-image.jpg'
img = cv2.imread(img_path)
result = table_engine(img)
for line in result:
print(line)
from PIL import Image
font_path = '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR//doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result,
font_path='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR//doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册