提交 28b2d43e 编写于 作者: W WenmuZhou

paddleocr whl adaptation dygraph

上级 eade2ce8
include LICENSE.txt include LICENSE.txt
include README.md include README.md
recursive-include ppocr/utils *.txt utility.py character.py check.py recursive-include ppocr/utils *.txt utility.py logging.py
recursive-include ppocr/data/det *.py recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py recursive-include ppocr/postprocess *.py
recursive-include ppocr/postprocess/lanms *.* recursive-include tools/infer *.py
recursive-include tools/infer *.py \ No newline at end of file
...@@ -261,6 +261,61 @@ im_show.save('result.jpg') ...@@ -261,6 +261,61 @@ im_show.save('result.jpg')
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
``` ```
### 使用网络图片或者numpy数组作为输入
1. 网络图片
代码使用
```python
from paddleocr import PaddleOCR, draw_ocr
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
result = ocr.ocr(img_path, cls=True)
for line in result:
print(line)
# 显示结果
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
命令行模式
```bash
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
```
2. numpy数组
仅通过代码使用时支持numpy数组作为输入
```python
from paddleocr import PaddleOCR, draw_ocr
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg'
img = cv2.imread(img_path)
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), 如果你自己训练的模型支持灰度图,可以将这句话的注释取消
result = ocr.ocr(img_path, cls=True)
for line in result:
print(line)
# 显示结果
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
## 参数说明 ## 参数说明
| 字段 | 说明 | 默认值 | | 字段 | 说明 | 默认值 |
...@@ -285,6 +340,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ ...@@ -285,6 +340,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
| max_text_length | 识别算法能识别的最大文字长度 | 25 | | max_text_length | 识别算法能识别的最大文字长度 | 25 |
| rec_char_dict_path | 识别模型字典路径,当rec_model_dir使用方式2传参时需要修改为自己的字典路径 | ./ppocr/utils/ppocr_keys_v1.txt | | rec_char_dict_path | 识别模型字典路径,当rec_model_dir使用方式2传参时需要修改为自己的字典路径 | ./ppocr/utils/ppocr_keys_v1.txt |
| use_space_char | 是否识别空格 | TRUE | | use_space_char | 是否识别空格 | TRUE |
| drop_score | 对输出按照分数(来自于识别模型)进行过滤,低于此分数的不返回 | 0.5 |
| use_angle_cls | 是否加载分类模型 | FALSE | | use_angle_cls | 是否加载分类模型 | FALSE |
| cls_model_dir | 分类模型所在文件夹。传参方式有两种,1. None: 自动下载内置模型到 `~/.paddleocr/cls`;2.自己转换好的inference模型路径,模型路径下必须包含model和params文件 | None | | cls_model_dir | 分类模型所在文件夹。传参方式有两种,1. None: 自动下载内置模型到 `~/.paddleocr/cls`;2.自己转换好的inference模型路径,模型路径下必须包含model和params文件 | None |
| cls_image_shape | 分类算法的输入图片尺寸 | "3, 48, 192" | | cls_image_shape | 分类算法的输入图片尺寸 | "3, 48, 192" |
...@@ -295,4 +351,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ ...@@ -295,4 +351,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
| lang | 模型语言类型,目前支持 中文(ch)和英文(en) | ch | | lang | 模型语言类型,目前支持 中文(ch)和英文(en) | ch |
| det | 前向时使用启动检测 | TRUE | | det | 前向时使用启动检测 | TRUE |
| rec | 前向时是否启动识别 | TRUE | | rec | 前向时是否启动识别 | TRUE |
| cls | 前向时是否启动分类 | FALSE | | cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
...@@ -271,6 +271,59 @@ im_show.save('result.jpg') ...@@ -271,6 +271,59 @@ im_show.save('result.jpg')
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
``` ```
### Use web images or numpy array as input
1. Web image
Use by code
```python
from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
result = ocr.ocr(img_path, cls=True)
for line in result:
print(line)
# show result
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
Use by command line
```bash
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
```
2. Numpy array
Support numpy array as input only when used by code
```python
from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg'
img = cv2.imread(img_path)
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), If your own training model supports grayscale images, you can uncomment this line
result = ocr.ocr(img_path, cls=True)
for line in result:
print(line)
# show result
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
## Parameter Description ## Parameter Description
| Parameter | Description | Default value | | Parameter | Description | Default value |
...@@ -295,6 +348,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ ...@@ -295,6 +348,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
| max_text_length | The maximum text length that the recognition algorithm can recognize | 25 | | max_text_length | The maximum text length that the recognition algorithm can recognize | 25 |
| rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt | | rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt |
| use_space_char | Whether to recognize spaces | TRUE | | use_space_char | Whether to recognize spaces | TRUE |
| drop_score | Filter the output by score (from the recognition model), and those below this score will not be returned | 0.5 |
| use_angle_cls | Whether to load classification model | FALSE | | use_angle_cls | Whether to load classification model | FALSE |
| cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None | | cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None |
| cls_image_shape | image shape of classification algorithm | "3,48,192" | | cls_image_shape | image shape of classification algorithm | "3,48,192" |
...@@ -305,4 +359,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ ...@@ -305,4 +359,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
| lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch | | lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch |
| det | Enable detction when `ppocr.ocr` func exec | TRUE | | det | Enable detction when `ppocr.ocr` func exec | TRUE |
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE | | rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
| cls | Enable classification when `ppocr.ocr` func exec | FALSE | | cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
...@@ -26,17 +26,50 @@ import requests ...@@ -26,17 +26,50 @@ import requests
from tqdm import tqdm from tqdm import tqdm
from tools.infer import predict_system from tools.infer import predict_system
from ppocr.utils.utility import initial_logger from ppocr.utils.logging import get_logger
logger = initial_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
model_params = { model_urls = {
'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar', 'det':
'rec': 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar',
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar', 'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/en/en_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/ic15_dict.txt'
},
'french': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/fr/french_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/ge/german_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/kr/korean_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/jp/japan_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
}
},
'cls':
'https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar'
} }
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
...@@ -54,8 +87,8 @@ def download_with_progressbar(url, save_path): ...@@ -54,8 +87,8 @@ def download_with_progressbar(url, save_path):
progress_bar.update(len(data)) progress_bar.update(len(data))
file.write(data) file.write(data)
progress_bar.close() progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("ERROR, something went wrong") logger.error("Something went wrong while downloading models")
sys.exit(0) sys.exit(0)
...@@ -63,7 +96,7 @@ def maybe_download(model_storage_directory, url): ...@@ -63,7 +96,7 @@ def maybe_download(model_storage_directory, url):
# using custom model # using custom model
if not os.path.exists(os.path.join( if not os.path.exists(os.path.join(
model_storage_directory, 'model')) or not os.path.exists( model_storage_directory, 'model')) or not os.path.exists(
os.path.join(model_storage_directory, 'params')): os.path.join(model_storage_directory, 'params')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path)) print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True) os.makedirs(model_storage_directory, exist_ok=True)
...@@ -84,53 +117,102 @@ def maybe_download(model_storage_directory, url): ...@@ -84,53 +117,102 @@ def maybe_download(model_storage_directory, url):
os.remove(tmp_path) os.remove(tmp_path)
def parse_args(): def parse_args(mMain=True, add_help=True):
import argparse import argparse
def str2bool(v): def str2bool(v):
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser() if mMain:
# params for prediction engine parser = argparse.ArgumentParser(add_help=add_help)
parser.add_argument("--use_gpu", type=str2bool, default=True) # params for prediction engine
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
parser.add_argument("--image_dir", type=str) # params for text detector
parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_model_dir", type=str, default=None) parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_max_side_len", type=float, default=960) parser.add_argument("--det_model_dir", type=str, default=None)
parser.add_argument("--det_limit_side_len", type=float, default=960)
# DB parmas parser.add_argument("--det_limit_type", type=str, default='max')
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) # DB parmas
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
# EAST parmas parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) # EAST parmas
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
# params for text recognizer parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str, default=None) # params for text recognizer
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_char_type", type=str, default='ch') parser.add_argument("--rec_model_dir", type=str, default=None)
parser.add_argument("--rec_batch_num", type=int, default=30) parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument( parser.add_argument("--rec_batch_num", type=int, default=30)
"--rec_char_dict_path", parser.add_argument("--max_text_length", type=int, default=25)
type=str, parser.add_argument("--rec_char_dict_path", type=str, default=None)
default="./ppocr/utils/ppocr_keys_v1.txt") parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--use_space_char", type=bool, default=True) parser.add_argument("--drop_score", type=float, default=0.5)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
# params for text classifier
parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--use_zero_copy_run", type=bool, default=False) parser.add_argument("--label_list", type=list, default=['0', '180'])
return parser.parse_args() parser.add_argument("--cls_batch_num", type=int, default=30)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
return parser.parse_args()
else:
return argparse.Namespace(use_gpu=True,
ir_optim=True,
use_tensorrt=False,
gpu_mem=8000,
image_dir='',
det_algorithm='DB',
det_model_dir=None,
det_limit_side_len=960,
det_limit_type='max',
det_db_thresh=0.3,
det_db_box_thresh=0.5,
det_db_unclip_ratio=2.0,
det_east_score_thresh=0.8,
det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2,
rec_algorithm='CRNN',
rec_model_dir=None,
rec_image_shape="3, 32, 320",
rec_char_type='ch',
rec_batch_num=30,
max_text_length=25,
rec_char_dict_path=None,
use_space_char=True,
drop_score=0.5,
cls_model_dir=None,
cls_image_shape="3, 48, 192",
label_list=['0', '180'],
cls_batch_num=30,
cls_thresh=0.9,
enable_mkldnn=False,
use_zero_copy_run=False,
use_pdserving=False,
lang='ch',
det=True,
rec=True,
use_angle_cls=False
)
class PaddleOCR(predict_system.TextSystem): class PaddleOCR(predict_system.TextSystem):
...@@ -140,18 +222,31 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -140,18 +222,31 @@ class PaddleOCR(predict_system.TextSystem):
args: args:
**kwargs: other params show in paddleocr --help **kwargs: other params show in paddleocr --help
""" """
postprocess_params = parse_args() postprocess_params = parse_args(mMain=False, add_help=False)
postprocess_params.__dict__.update(**kwargs) postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
if postprocess_params.rec_char_dict_path is None:
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path']
# init model dir # init model dir
if postprocess_params.det_model_dir is None: if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det') postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
if postprocess_params.rec_model_dir is None: if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec') postprocess_params.rec_model_dir = os.path.join(
BASE_DIR, 'rec/{}'.format(lang))
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params) print(postprocess_params)
# download model # download model
maybe_download(postprocess_params.det_model_dir, model_params['det']) maybe_download(postprocess_params.det_model_dir, model_urls['det'])
maybe_download(postprocess_params.rec_model_dir, model_params['rec']) maybe_download(postprocess_params.rec_model_dir,
model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
...@@ -166,7 +261,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -166,7 +261,7 @@ class PaddleOCR(predict_system.TextSystem):
# init det_model and rec_model # init det_model and rec_model
super().__init__(postprocess_params) super().__init__(postprocess_params)
def ocr(self, img, det=True, rec=True): def ocr(self, img, det=True, rec=True, cls=False):
""" """
ocr with paddleocr ocr with paddleocr
args: args:
...@@ -175,7 +270,16 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -175,7 +270,16 @@ class PaddleOCR(predict_system.TextSystem):
rec: use text recognition or not, if false, only det will be exec. default is True rec: use text recognition or not, if false, only det will be exec. default is True
""" """
assert isinstance(img, (np.ndarray, list, str)) assert isinstance(img, (np.ndarray, list, str))
if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false')
exit(0)
self.use_angle_cls = cls
if isinstance(img, str): if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img image_file = img
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -183,6 +287,8 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -183,6 +287,8 @@ class PaddleOCR(predict_system.TextSystem):
if img is None: if img is None:
logger.error("error in loading image:{}".format(image_file)) logger.error("error in loading image:{}".format(image_file))
return None return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if det and rec: if det and rec:
dt_boxes, rec_res = self.__call__(img) dt_boxes, rec_res = self.__call__(img)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
...@@ -194,20 +300,34 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -194,20 +300,34 @@ class PaddleOCR(predict_system.TextSystem):
else: else:
if not isinstance(img, list): if not isinstance(img, list):
img = [img] img = [img]
if self.use_angle_cls:
img, cls_res, elapse = self.text_classifier(img)
if not rec:
return cls_res
rec_res, elapse = self.text_recognizer(img) rec_res, elapse = self.text_recognizer(img)
return rec_res return rec_res
def main(): def main():
# for com # for cmd
args = parse_args() args = parse_args(mMain=True)
image_file_list = get_image_file_list(args.image_dir) image_dir = args.image_dir
if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0: if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir)) logger.error('no images find in {}'.format(args.image_dir))
return return
ocr_engine = PaddleOCR()
ocr_engine = PaddleOCR(**(args.__dict__))
for img_path in image_file_list: for img_path in image_file_list:
print(img_path) logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec) result = ocr_engine.ocr(img_path,
for line in result: det=args.det,
print(line) rec=args.rec,
\ No newline at end of file cls=args.use_angle_cls)
if result is not None:
for line in result:
logger.info(line)
...@@ -32,7 +32,7 @@ setup( ...@@ -32,7 +32,7 @@ setup(
package_dir={'paddleocr': ''}, package_dir={'paddleocr': ''},
include_package_data=True, include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]}, entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
version='0.0.3', version='2.0',
install_requires=requirements, install_requires=requirements,
license='Apache License 2.0', license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
...@@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif ...@@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.infer.utility import draw_ocr_box_txt from tools.infer.utility import draw_ocr_box_txt
logger = get_logger()
class TextSystem(object): class TextSystem(object):
def __init__(self, args): def __init__(self, args):
self.text_detector = predict_det.TextDetector(args) self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args) self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
if self.use_angle_cls: if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args) self.text_classifier = predict_cls.TextClassifier(args)
...@@ -81,7 +85,8 @@ class TextSystem(object): ...@@ -81,7 +85,8 @@ class TextSystem(object):
def __call__(self, img): def __call__(self, img):
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse)) logger.info("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
return None, None return None, None
img_crop_list = [] img_crop_list = []
...@@ -99,9 +104,16 @@ class TextSystem(object): ...@@ -99,9 +104,16 @@ class TextSystem(object):
len(img_crop_list), elapse)) len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
logger.info("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) logger.info("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res) # self.print_draw_crop_rec_res(img_crop_list, rec_res)
return dt_boxes, rec_res filter_boxes, filter_rec_res = [], []
for box, rec_reuslt in zip(dt_boxes, rec_res):
text, score = rec_reuslt
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_reuslt)
return filter_boxes, filter_rec_res
def sorted_boxes(dt_boxes): def sorted_boxes(dt_boxes):
...@@ -117,8 +129,8 @@ def sorted_boxes(dt_boxes): ...@@ -117,8 +129,8 @@ def sorted_boxes(dt_boxes):
_boxes = list(sorted_boxes) _boxes = list(sorted_boxes)
for i in range(num_boxes - 1): for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]): (_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i] tmp = _boxes[i]
_boxes[i] = _boxes[i + 1] _boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp _boxes[i + 1] = tmp
...@@ -143,12 +155,8 @@ def main(args): ...@@ -143,12 +155,8 @@ def main(args):
elapse = time.time() - starttime elapse = time.time() - starttime
logger.info("Predict time of %s: %.3fs" % (image_file, elapse)) logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
dt_num = len(dt_boxes) for text, score in rec_res:
for dno in range(dt_num): logger.info("{}, {:.3f}".format(text, score))
text, score = rec_res[dno]
if score >= drop_score:
text_str = "%s, %.3f" % (text, score)
logger.info(text_str)
if is_visualize: if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
...@@ -174,5 +182,4 @@ def main(args): ...@@ -174,5 +182,4 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
logger = get_logger() main(utility.parse_args())
main(utility.parse_args()) \ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册