From a5f7511505662d047a3dd2db10512ed90c1e82bf Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Sat, 5 Jun 2021 17:33:50 +0800 Subject: [PATCH] mv download func to ppocr/utils/network.py --- paddleocr.py | 126 +++++++++++++---------------------------- ppocr/utils/network.py | 66 +++++++++++++++++++++ 2 files changed, 106 insertions(+), 86 deletions(-) create mode 100644 ppocr/utils/network.py diff --git a/paddleocr.py b/paddleocr.py index 1e4d94ff..708f20b1 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -21,15 +21,13 @@ sys.path.append(os.path.join(__dir__, '')) import cv2 import numpy as np from pathlib import Path -import tarfile -import requests -from tqdm import tqdm from tools.infer import predict_system from ppocr.utils.logging import get_logger logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list +from ppocr.utils.network import maybe_download, download_with_progressbar from tools.infer.utility import draw_ocr, init_args, str2bool __all__ = ['PaddleOCR'] @@ -37,84 +35,84 @@ __all__ = ['PaddleOCR'] model_urls = { 'det': { 'ch': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', 'en': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar' + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar' }, 'rec': { 'ch': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/ppocr_keys_v1.txt' }, 'en': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/en_dict.txt' }, 'french': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/french_dict.txt' }, 'german': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/german_dict.txt' }, 'korean': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/korean_dict.txt' }, 'japan': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/japan_dict.txt' }, 'chinese_cht': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt' }, 'ta': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/ta_dict.txt' }, 'te': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/te_dict.txt' }, 'ka': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/ka_dict.txt' }, 'latin': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/latin_dict.txt' }, 'arabic': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/arabic_dict.txt' }, 'cyrillic': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt' }, 'devanagari': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/devanagari_dict.txt' } }, 'cls': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar' + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar' } SUPPORT_DET_MODEL = ['DB'] @@ -123,50 +121,6 @@ SUPPORT_REC_MODEL = ['CRNN'] BASE_DIR = os.path.expanduser("~/.paddleocr/") -def download_with_progressbar(url, save_path): - response = requests.get(url, stream=True) - total_size_in_bytes = int(response.headers.get('content-length', 0)) - block_size = 1024 # 1 Kibibyte - progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) - with open(save_path, 'wb') as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - progress_bar.close() - if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes: - logger.error("Something went wrong while downloading models") - sys.exit(0) - - -def maybe_download(model_storage_directory, url): - # using custom model - tar_file_name_list = [ - 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel' - ] - if not os.path.exists( - os.path.join(model_storage_directory, 'inference.pdiparams') - ) or not os.path.exists( - os.path.join(model_storage_directory, 'inference.pdmodel')): - tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) - print('download {} to {}'.format(url, tmp_path)) - os.makedirs(model_storage_directory, exist_ok=True) - download_with_progressbar(url, tmp_path) - with tarfile.open(tmp_path, 'r') as tarObj: - for member in tarObj.getmembers(): - filename = None - for tar_file_name in tar_file_name_list: - if tar_file_name in member.name: - filename = tar_file_name - if filename is None: - continue - file = tarObj.extractfile(member) - with open( - os.path.join(model_storage_directory, filename), - 'wb') as f: - f.write(file.read()) - os.remove(tmp_path) - - def parse_args(mMain=True): import argparse parser = init_args() @@ -194,10 +148,10 @@ class PaddleOCR(predict_system.TextSystem): args: **kwargs: other params show in paddleocr --help """ - postprocess_params = parse_args(mMain=False) - postprocess_params.__dict__.update(**kwargs) - self.use_angle_cls = postprocess_params.use_angle_cls - lang = postprocess_params.lang + params = parse_args(mMain=False) + params.__dict__.update(**kwargs) + self.use_angle_cls = params.use_angle_cls + lang = params.lang latin_lang = [ 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', @@ -223,46 +177,46 @@ class PaddleOCR(predict_system.TextSystem): lang = "devanagari" assert lang in model_urls[ 'rec'], 'param lang must in {}, but got {}'.format( - model_urls['rec'].keys(), lang) + model_urls['rec'].keys(), lang) if lang == "ch": det_lang = "ch" else: det_lang = "en" use_inner_dict = False - if postprocess_params.rec_char_dict_path is None: + if params.rec_char_dict_path is None: use_inner_dict = True - postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ + params.rec_char_dict_path = model_urls['rec'][lang][ 'dict_path'] # init model dir - if postprocess_params.det_model_dir is None: - postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION, + if params.det_model_dir is None: + params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det', det_lang) - if postprocess_params.rec_model_dir is None: - postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION, + if params.rec_model_dir is None: + params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec', lang) - if postprocess_params.cls_model_dir is None: - postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls') - print(postprocess_params) + if params.cls_model_dir is None: + params.cls_model_dir = os.path.join(BASE_DIR, 'cls') # download model - maybe_download(postprocess_params.det_model_dir, + maybe_download(params.det_model_dir, model_urls['det'][det_lang]) - maybe_download(postprocess_params.rec_model_dir, + maybe_download(params.rec_model_dir, model_urls['rec'][lang]['url']) - maybe_download(postprocess_params.cls_model_dir, model_urls['cls']) + maybe_download(params.cls_model_dir, model_urls['cls']) - if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: + if params.det_algorithm not in SUPPORT_DET_MODEL: logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) sys.exit(0) - if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: + if params.rec_algorithm not in SUPPORT_REC_MODEL: logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) sys.exit(0) if use_inner_dict: - postprocess_params.rec_char_dict_path = str( - Path(__file__).parent / postprocess_params.rec_char_dict_path) + params.rec_char_dict_path = str( + Path(__file__).parent / params.rec_char_dict_path) + print(params) # init det_model and rec_model - super().__init__(postprocess_params) + super().__init__(params) def ocr(self, img, det=True, rec=True, cls=True): """ diff --git a/ppocr/utils/network.py b/ppocr/utils/network.py new file mode 100644 index 00000000..7d98f5c3 --- /dev/null +++ b/ppocr/utils/network.py @@ -0,0 +1,66 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import tarfile +import requests +from tqdm import tqdm + +from ppocr.utils.logging import get_logger + +def download_with_progressbar(url, save_path): + logger = get_logger() + response = requests.get(url, stream=True) + total_size_in_bytes = int(response.headers.get('content-length', 0)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) + with open(save_path, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes: + logger.error("Something went wrong while downloading models") + sys.exit(0) + + +def maybe_download(model_storage_directory, url): + # using custom model + tar_file_name_list = [ + 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel' + ] + if not os.path.exists( + os.path.join(model_storage_directory, 'inference.pdiparams') + ) or not os.path.exists( + os.path.join(model_storage_directory, 'inference.pdmodel')): + tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) + print('download {} to {}'.format(url, tmp_path)) + os.makedirs(model_storage_directory, exist_ok=True) + download_with_progressbar(url, tmp_path) + with tarfile.open(tmp_path, 'r') as tarObj: + for member in tarObj.getmembers(): + filename = None + for tar_file_name in tar_file_name_list: + if tar_file_name in member.name: + filename = tar_file_name + if filename is None: + continue + file = tarObj.extractfile(member) + with open( + os.path.join(model_storage_directory, filename), + 'wb') as f: + f.write(file.read()) + os.remove(tmp_path) + -- GitLab