# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(__file__)
# Make the directory of this file importable before the project-local
# imports below are resolved.
sys.path.append(os.path.join(__dir__, ''))

import cv2
import numpy as np
from pathlib import Path
import tarfile
import requests
from tqdm import tqdm

from tools.infer import predict_system
from ppocr.utils.utility import initial_logger

logger = initial_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list

__all__ = ['PaddleOCR']

# Default inference model archives, fetched when no local model is found.
model_params = {
    'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar',
    'rec':
    'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar',
}

SUPPORT_DET_MODEL = ['DB']
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")


def download_with_progressbar(url, save_path):
    """
    Stream `url` into `save_path` while displaying a tqdm progress bar.
    args:
        url: address of the file to download
        save_path: local path the downloaded bytes are written to
    The process exits with a non-zero status when the server announced a
    content length and the number of bytes received does not match it.
    """
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    with open(save_path, 'wb') as out_file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            out_file.write(data)
    progress_bar.close()
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        logger.error("ERROR, something went wrong")
        # A truncated download is a failure: do not report success (0).
        sys.exit(1)


def maybe_download(model_storage_directory, url):
    """
    Download and unpack a model archive unless the model is already present.
    args:
        model_storage_directory: directory expected to hold the `model` and
            `params` files
        url: address of the tar archive containing those two files
    Nothing is downloaded when both files already exist (e.g. when a custom
    model directory is used).
    """
    model_path = os.path.join(model_storage_directory, 'model')
    params_path = os.path.join(model_storage_directory, 'params')
    # using custom model
    if os.path.exists(model_path) and os.path.exists(params_path):
        return

    tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
    print('download {} to {}'.format(url, tmp_path))
    os.makedirs(model_storage_directory, exist_ok=True)
    try:
        download_with_progressbar(url, tmp_path)
        with tarfile.open(tmp_path, 'r') as tarObj:
            for member in tarObj.getmembers():
                # Match on the file name only, so that a parent directory
                # whose name happens to contain "model" cannot cause the
                # params file to be written as the model file.
                base_name = os.path.basename(member.name)
                if "model" in base_name:
                    filename = 'model'
                elif "params" in base_name:
                    filename = 'params'
                else:
                    continue
                extracted = tarObj.extractfile(member)
                if extracted is None:
                    # Directories and other non-file members carry no data.
                    continue
                with open(
                        os.path.join(model_storage_directory, filename),
                        'wb') as f:
                    f.write(extracted.read())
    finally:
        # Never leave the (possibly partial) archive behind, even when the
        # download or the extraction fails.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def parse_args():
    """
    Parse the command line (sys.argv) into an argparse.Namespace holding the
    prediction engine, detector and recognizer options.
    """
    import argparse

    def str2bool(v):
        # argparse hands over the raw string; bool("False") would be True.
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=8000)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str, default=None)
    parser.add_argument("--det_max_side_len", type=float, default=960)

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str, default=None)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=30)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    # The three flags below used type=bool, which turns every non-empty
    # string (including "False") into True; str2bool parses them properly.
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)

    parser.add_argument("--det", type=str2bool, default=True)
    parser.add_argument("--rec", type=str2bool, default=True)
    parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
    return parser.parse_args()


class PaddleOCR(predict_system.TextSystem):
    def __init__(self, **kwargs):
        """
        paddleocr package
        args:
            **kwargs: other params show in paddleocr --help
        The defaults come from parse_args(), so options given on the command
        line of the running process are picked up as well; kwargs override
        them. Missing default models are downloaded below BASE_DIR.
        """
        postprocess_params = parse_args()
        postprocess_params.__dict__.update(**kwargs)

        # init model dir
        if postprocess_params.det_model_dir is None:
            postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
        if postprocess_params.rec_model_dir is None:
            postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec')
        print(postprocess_params)

        # Validate the configuration before spending time on downloads, and
        # exit with a failure status rather than 0.
        if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
            logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
            sys.exit(1)
        if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
            logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
            sys.exit(1)

        # download model
        maybe_download(postprocess_params.det_model_dir, model_params['det'])
        maybe_download(postprocess_params.rec_model_dir, model_params['rec'])

        # Resolve the dictionary path relative to this file.
        postprocess_params.rec_char_dict_path = Path(
            __file__).parent / postprocess_params.rec_char_dict_path

        # init det_model and rec_model
        super().__init__(postprocess_params)

    def ocr(self, img, det=True, rec=True):
        """
        ocr with paddleocr
        args:
            img: img for ocr, support ndarray, img_path (str) and list.
                In recognition-only mode a list is passed to the recognizer
                as-is and any other input is wrapped into a one-element list.
            det: use text detection or not, if false, only rec will be exec.
                default is True
            rec: use text recognition or not, if false, only det will be
                exec. default is True
        returns:
            det and rec: list of [box, recognition result] pairs
            det only: list of boxes
            rec only (also when both flags are false): the recognizer output
            None when the image cannot be loaded or no boxes are returned
        """
        assert isinstance(img, (np.ndarray, list, str))
        if isinstance(img, str):
            image_file = img
            img, flag = check_and_read_gif(image_file)
            if not flag:
                img = cv2.imread(image_file)
            if img is None:
                logger.error("error in loading image:{}".format(image_file))
                return None
        if det and rec:
            dt_boxes, rec_res = self.__call__(img)
            # Same guard as the detection-only branch below.
            # NOTE(review): presumably TextSystem yields None when nothing is
            # detected - confirm against tools.infer.predict_system.
            if dt_boxes is None or rec_res is None:
                return None
            return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
        elif det and not rec:
            dt_boxes, elapse = self.text_detector(img)
            if dt_boxes is None:
                return None
            return [box.tolist() for box in dt_boxes]
        else:
            if not isinstance(img, list):
                img = [img]
            rec_res, elapse = self.text_recognizer(img)
            return rec_res


def main():
    """Command line entry point: run OCR on every image in --image_dir."""
    # for command line usage
    args = parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    if len(image_file_list) == 0:
        logger.error('no images find in {}'.format(args.image_dir))
        return
    ocr_engine = PaddleOCR()
    for img_path in image_file_list:
        print(img_path)
        result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec)
        if result is None:
            # ocr() returns None for unreadable images or empty detections;
            # move on to the next image instead of failing on iteration.
            continue
        for line in result:
            print(line)