未验证 提交 2df020d0 编写于 作者: K KP 提交者: GitHub

Merge pull request #1864 from rainyfly/debug_ppocr_module

add ppocr module
# ch_pp-ocrv3
|模型名称|ch_pp-ocrv3|
| :--- | :---: |
|类别|图像-文字识别|
|网络|Differentiable Binarization+SVTR_LCNet|
|数据集|icdar2015数据集|
|是否支持Fine-tuning|否|
|模型大小|13M|
|最新更新日期|2022-05-11|
|数据指标|-|
## 一、模型基本信息
- ### 应用效果展示
- [OCR文字识别场景在线体验](https://www.paddlepaddle.org.cn/hub/scene/ocr)
- 样例结果示例:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/167818854-96811631-d40c-4d07-9aae-b78d4514c917.jpg" width = "600" hspace='10'/> <br />
</p>
- ### 模型介绍
- PP-OCR是PaddleOCR自研的实用的超轻量OCR系统。在实现前沿算法的基础上,考虑精度与速度的平衡,进行模型瘦身和深度优化,使其尽可能满足产业落地需求。该系统包含文本检测和文本识别两个阶段,其中文本检测算法选用DB,文本识别算法选用CRNN,并在检测和识别模块之间添加文本方向分类器,以应对不同方向的文本识别。当前模块为PP-OCRv3,在PP-OCRv2的基础上,针对检测模型和识别模型,进行了共计9个方面的升级,进一步提升了模型效果。
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/ppocrv3_framework.png" width="800" hspace='10'/> <br />
</p>
- 更多详情参考:[PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/PP-OCRv3_introduction.md)
## 二、安装
- ### 1、环境依赖
- paddlepaddle >= 2.2
- paddlehub >= 2.2 | [如何安装paddlehub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2、安装
- ```shell
$ hub install ch_pp-ocrv3
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## 三、模型API预测
- ### 1、命令行预测
- ```shell
$ hub run ch_pp-ocrv3 --input_path "/PATH/TO/IMAGE"
```
- 通过命令行方式实现文字识别模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、代码示例
- ```python
import paddlehub as hub
import cv2
ocr = hub.Module(name="ch_pp-ocrv3", enable_mkldnn=True) # mkldnn加速仅在CPU下有效
result = ocr.recognize_text(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = ocr.recognize_text(paths=['/PATH/TO/IMAGE'])
```
- ### 3、API
- ```python
__init__(text_detector_module=None, enable_mkldnn=False)
```
- 构造用于文本检测的模块
- **参数**
- text_detector_module(str): 文字检测PaddleHub Module名字,如设置为None,则默认使用[ch_pp-ocrv3_det Module](../ch_pp-ocrv3_det/)。其作用为检测图片当中的文本。
- enable_mkldnn(bool): 是否开启mkldnn加速CPU计算。该参数仅在CPU运行下设置有效。默认为False。
- ```python
def recognize_text(images=[],
paths=[],
use_gpu=False,
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
text_thresh=0.5,
angle_classification_thresh=0.9,
det_db_unclip_ratio=1.5)
```
- 预测API,检测输入图片中的所有中文文本的位置。
- **参数**
- paths (list\[str\]): 图片的路径;
- images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
- use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量**
- box\_thresh (float): 检测文本框置信度的阈值;
- text\_thresh (float): 识别中文文本置信度的阈值;
- angle_classification_thresh(float): 文本角度分类置信度的阈值
- visualization (bool): 是否将识别结果保存为图片文件;
- output\_dir (str): 图片的保存路径,默认设为 ocr\_result;
- det\_db\_unclip\_ratio: 设置检测框的大小;
- **返回**
- res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
- data (list\[dict\]): 识别文本结果,列表中每一个元素为 dict,各字段为:
- text(str): 识别得到的文本
- confidence(float): 识别文本结果置信度
- text_box_position(list): 文本框在原图中的像素坐标,4*2的矩阵,依次表示文本框左下、右下、右上、左上顶点的坐标
如果无识别结果则data为\[\]
- save_path (str, optional): 识别结果的保存路径,如不保存图片则save_path为''
## 四、服务部署
- PaddleHub Serving 可以部署一个目标检测的在线服务。
- ### 第一步:启动PaddleHub Serving
- 运行启动命令:
- ```shell
$ hub serving start -m ch_pp-ocrv3
```
- 这样就完成了一个目标检测的服务化API的部署,默认端口号为8866。
- **NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
- ### 第二步:发送预测请求
- 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
- ```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# 发送HTTP请求
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ch_pp-ocrv3"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
```
## 五、更新历史
* 1.0.0
初始发布
- ```shell
$ hub install ch_pp-ocrv3==1.0.0
```
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import string
import numpy as np
class CharacterOps(object):
""" Convert between text-label and text-index
Args:
config: config from yaml file
"""
def __init__(self, config):
self.character_type = config['character_type']
self.max_text_len = config['max_text_length']
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
# use the custom dictionary
elif self.character_type == "ch":
character_dict_path = config['character_dict_path']
add_space = False
if 'use_space_char' in config:
add_space = config['use_space_char']
self.character_str = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str.append(line)
if add_space:
self.character_str.append(" ")
dict_character = list(self.character_str)
elif self.character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
else:
self.character_str = None
self.beg_str = "sos"
self.end_str = "eos"
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
def encode(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if self.character_type == "en":
text = text.lower()
text_list = []
for char in text:
if char not in self.dict:
continue
text_list.append(self.dict[char])
text = np.array(text_list)
return text
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_char_num(self):
return len(self.character)
def get_beg_end_flag_idx(self, beg_or_end):
if self.loss_type == "attention":
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx"\
% beg_or_end
return idx
else:
err = "error in get_beg_end_flag_idx when using the loss %s"\
% (self.loss_type)
assert False, err
def get_ignored_tokens(self):
return [0] # for ctc blank
def cal_predicts_accuracy(char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate=False):
"""
Calculate prediction accuracy
Args:
char_ops: CharacterOps
preds: preds result,text index
preds_lod: lod tensor of preds
labels: label of input image, text index
labels_lod: lod tensor of label
is_remove_duplicate: Whether to remove duplicate characters,
The default is False
Return:
acc: The accuracy of test set
acc_num: The correct number of samples predicted
img_num: The total sample number of the test set
"""
acc_num = 0
img_num = 0
for ino in range(len(labels_lod) - 1):
beg_no = preds_lod[ino]
end_no = preds_lod[ino + 1]
preds_text = preds[beg_no:end_no].reshape(-1)
preds_text = char_ops.decode(preds_text, is_remove_duplicate)
beg_no = labels_lod[ino]
end_no = labels_lod[ino + 1]
labels_text = labels[beg_no:end_no].reshape(-1)
labels_text = char_ops.decode(labels_text, is_remove_duplicate)
img_num += 1
if preds_text == labels_text:
acc_num += 1
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def cal_predicts_accuracy_srn(char_ops, preds, labels, max_text_len, is_debug=False):
acc_num = 0
img_num = 0
char_num = char_ops.get_char_num()
total_len = preds.shape[0]
img_num = int(total_len / max_text_len)
for i in range(img_num):
cur_label = []
cur_pred = []
for j in range(max_text_len):
if labels[j + i * max_text_len] != int(char_num - 1): #0
cur_label.append(labels[j + i * max_text_len][0])
else:
break
for j in range(max_text_len + 1):
if j < len(cur_label) and preds[j + i * max_text_len][0] != cur_label[j]:
break
elif j == len(cur_label) and j == max_text_len:
acc_num += 1
break
elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(char_num - 1):
acc_num += 1
break
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def convert_rec_attention_infer_res(preds):
img_num = preds.shape[0]
target_lod = [0]
convert_ids = []
for ino in range(img_num):
end_pos = np.where(preds[ino, :] == 1)[0]
if len(end_pos) <= 1:
text_list = preds[ino, 1:]
else:
text_list = preds[ino, 1:end_pos[1]]
target_lod.append(target_lod[ino] + len(text_list))
convert_ids = convert_ids + list(text_list)
convert_ids = np.array(convert_ids)
convert_ids = convert_ids.reshape((-1, 1))
return convert_ids, target_lod
def convert_rec_label_to_lod(ori_labels):
img_num = len(ori_labels)
target_lod = [0]
convert_ids = []
for ino in range(img_num):
target_lod.append(target_lod[ino] + len(ori_labels[ino]))
convert_ids = convert_ids + list(ori_labels[ino])
convert_ids = np.array(convert_ids)
convert_ids = convert_ids.reshape((-1, 1))
return convert_ids, target_lod
此差异已折叠。
# -*- coding:utf-8 -*-
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import math
import cv2
import numpy as np
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
def draw_ocr(image, boxes, txts, scores, font_file, draw_txt=True, drop_score=0.5):
"""
Visualize the results of OCR detection and recognition
args:
image(Image|array): RGB image
boxes(list): boxes with shape(N, 4, 2)
txts(list): the texts
scores(list): txxs corresponding scores
draw_txt(bool): whether draw text or not
drop_score(float): only scores greater than drop_threshold will be visualized
return(array):
the visualized img
"""
if scores is None:
scores = [1] * len(boxes)
for (box, score) in zip(boxes, scores):
if score < drop_score or math.isnan(score):
continue
box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
if draw_txt:
img = np.array(resize_img(image, input_size=600))
txt_img = text_visual(txts, scores, font_file, img_h=img.shape[0], img_w=600, threshold=drop_score)
img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
return img
return image
def text_visual(texts, scores, font_file, img_h=400, img_w=600, threshold=0.):
"""
create new blank img and draw txt on it
args:
texts(list): the text will be draw
scores(list|None): corresponding score of each txt
img_h(int): the height of blank img
img_w(int): the width of blank img
return(array):
"""
if scores is not None:
assert len(texts) == len(scores), "The number of txts and corresponding scores must match"
def create_blank_img():
blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
blank_img[:, img_w - 1:] = 0
blank_img = Image.fromarray(blank_img).convert("RGB")
draw_txt = ImageDraw.Draw(blank_img)
return blank_img, draw_txt
blank_img, draw_txt = create_blank_img()
font_size = 20
txt_color = (0, 0, 0)
font = ImageFont.truetype(font_file, font_size, encoding="utf-8")
gap = font_size + 5
txt_img_list = []
count, index = 1, 0
for idx, txt in enumerate(texts):
index += 1
if scores[idx] < threshold or math.isnan(scores[idx]):
index -= 1
continue
first_line = True
while str_count(txt) >= img_w // font_size - 4:
tmp = txt
txt = tmp[:img_w // font_size - 4]
if first_line:
new_txt = str(index) + ': ' + txt
first_line = False
else:
new_txt = ' ' + txt
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
txt = tmp[img_w // font_size - 4:]
if count >= img_h // gap - 1:
txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img()
count = 0
count += 1
if first_line:
new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
else:
new_txt = " " + txt + " " + '%.3f' % (scores[idx])
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
# whether add new blank img or not
if count >= img_h // gap - 1 and idx + 1 < len(texts):
txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img()
count = 0
count += 1
txt_img_list.append(np.array(blank_img))
if len(txt_img_list) == 1:
blank_img = np.array(txt_img_list[0])
else:
blank_img = np.concatenate(txt_img_list, axis=1)
return np.array(blank_img)
def str_count(s):
"""
Count the number of Chinese characters,
a single English character and a single number
equal to half the length of Chinese characters.
args:
s(string): the input of string
return(int):
the number of Chinese characters
"""
import string
count_zh = count_pu = 0
s_len = len(s)
en_dg_count = 0
for c in s:
if c in string.ascii_letters or c.isdigit() or c.isspace():
en_dg_count += 1
elif c.isalpha():
count_zh += 1
else:
count_pu += 1
return s_len - math.ceil(en_dg_count / 2)
def resize_img(img, input_size=600):
img = np.array(img)
im_shape = img.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(input_size) / float(im_size_max)
im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
return im
def get_image_ext(image):
if image.shape[2] == 4:
return ".png"
return ".jpg"
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
# ch_pp-ocrv3_det
|模型名称|ch_pp-ocrv3_det|
| :--- | :---: |
|类别|图像-文字检测|
|网络|Differentiable Binarization|
|数据集|icdar2015数据集|
|是否支持Fine-tuning|否|
|模型大小|3.7MB|
|最新更新日期|2022-05-11|
|数据指标|-|
## 一、模型基本信息
- ### 应用效果展示
- 样例结果示例:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/167821705-f38496ef-daae-4de1-9363-3df20424f525.jpg" width="500" alt="package" >
</p>
- ### 模型介绍
- DB(Differentiable Binarization)是一种基于分割的文本检测算法。此类算法可以更好地处理弯曲等不规则形状文本,因此检测效果往往会更好。但其后处理步骤中将分割结果转化为检测框的流程复杂,耗时严重。DB将二值化阈值加入训练中学习,可以获得更准确的检测边界,从而简化后处理流程。该Module是PP-OCRv3的检测模型,对PP-OCRv2中的CML(Collaborative Mutual Learning) 协同互学习文本检测蒸馏策略进行了升级。
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/ppocrv3_framework.png" width="800" hspace='10'/> <br />
</p>
- 更多详情参考:[PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/PP-OCRv3_introduction.md)
## 二、安装
- ### 1、环境依赖
- paddlepaddle >= 2.2
- paddlehub >= 2.2 | [如何安装paddlehub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2、安装
- ```shell
$ hub install ch_pp-ocrv3_det
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## 三、模型API预测
- ### 1、命令行预测
- ```shell
$ hub run ch_pp-ocrv3_det --input_path "/PATH/TO/IMAGE"
```
- 通过命令行方式实现文字识别模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、代码示例
- ```python
import paddlehub as hub
import cv2
text_detector = hub.Module(name="ch_pp-ocrv3_det", enable_mkldnn=True) # mkldnn加速仅在CPU下有效
result = text_detector.detect_text(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result =text_detector.detect_text(paths=['/PATH/TO/IMAGE'])
```
- ### 3、API
- ```python
__init__(enable_mkldnn=False)
```
- 构造检测模块的对象
- **参数**
- enable_mkldnn(bool): 是否开启mkldnn加速CPU计算。该参数仅在CPU运行下设置有效。默认为False。
- ```python
def detect_text(paths=[],
images=[],
use_gpu=False,
output_dir='detection_result',
box_thresh=0.5,
visualization=False,
det_db_unclip_ratio=1.5)
```
- 预测API,检测输入图片中的所有中文文本的位置。
- **参数**
- paths (list\[str\]): 图片的路径;
- images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
- use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量**
- box\_thresh (float): 检测文本框置信度的阈值;
- visualization (bool): 是否将识别结果保存为图片文件;
- output\_dir (str): 图片的保存路径,默认设为 detection\_result;
- det\_db\_unclip\_ratio: 设置检测框的大小;
- **返回**
- res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
- data (list): 检测文本框结果,文本框在原图中的像素坐标,4*2的矩阵,依次表示文本框左下、右下、右上、左上顶点的坐标
- save_path (str): 识别结果的保存路径, 如不保存图片则save_path为''
## 四、服务部署
- PaddleHub Serving 可以部署一个目标检测的在线服务。
- ### 第一步:启动PaddleHub Serving
- 运行启动命令:
- ```shell
$ hub serving start -m ch_pp-ocrv3_det
```
- 这样就完成了一个目标检测的服务化API的部署,默认端口号为8866。
- **NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
- ### 第二步:发送预测请求
- 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
- ```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# 发送HTTP请求
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ch_pp-ocrv3_det"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
```
## 五、更新历史
* 1.0.0
初始发布
- ```shell
$ hub install ch_pp-ocrv3_det==1.0.0
```
# -*- coding:utf-8 -*-
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import base64
import math
import os
import time
import cv2
import numpy as np
import paddle.fluid as fluid
import paddle.inference as paddle_infer
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
from paddle.fluid.core import PaddleTensor
from PIL import Image
import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
@moduleinfo(
name="ch_pp-ocrv3_det",
version="1.0.0",
summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChPPOCRv3Det(hub.Module):
def _initialize(self, enable_mkldnn=False):
"""
initialize with the necessary elements
"""
self.pretrained_model_path = os.path.join(self.directory, 'inference_model', 'ppocrv3_det')
self.enable_mkldnn = enable_mkldnn
self._set_config()
def check_requirements(self):
try:
import shapely, pyclipper
except:
raise ImportError(
'This module requires the shapely, pyclipper tools. The running environment does not meet the requirements. Please install the two packages.'
)
def _set_config(self):
"""
predictor config setting
"""
model_file_path = self.pretrained_model_path + '.pdmodel'
params_file_path = self.pretrained_model_path + '.pdiparams'
config = paddle_infer.Config(model_file_path, params_file_path)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
config.enable_use_gpu(8000, 0)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)
if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info()
# use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
self.predictor = paddle_infer.create_predictor(config)
input_names = self.predictor.get_input_names()
self.input_tensor = self.predictor.get_input_handle(input_names[0])
output_names = self.predictor.get_output_names()
self.output_tensors = []
for output_name in output_names:
output_tensor = self.predictor.get_output_handle(output_name)
self.output_tensors.append(output_tensor)
def read_images(self, paths=[]):
images = []
for img_path in paths:
assert os.path.isfile(img_path), "The {} isn't a valid file.".format(img_path)
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
continue
images.append(img)
return images
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
diff = np.diff(pts, axis=1)
rect[1] = pts[np.argmin(diff)]
rect[3] = pts[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def detect_text(self,
images=[],
paths=[],
use_gpu=False,
output_dir='detection_result',
visualization=False,
box_thresh=0.5,
det_db_unclip_ratio=1.5):
"""
Get the text box in the predicted images.
Args:
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths
paths (list[str]): The paths of images. If paths not images
use_gpu (bool): Whether to use gpu. Default false.
output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence
det_db_unclip_ratio(float): unclip ratio for post processing in DB detection.
Returns:
res (list): The result of text detection box and save path of images.
"""
self.check_requirements()
from .processor import DBProcessTest, DBPostProcess, draw_boxes, get_image_ext
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
)
if images != [] and isinstance(images, list) and paths == []:
predicted_data = images
elif images == [] and isinstance(paths, list) and paths != []:
predicted_data = self.read_images(paths)
else:
raise TypeError("The input data is inconsistent with expectations.")
assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
preprocessor = DBProcessTest(params={'max_side_len': 960})
postprocessor = DBPostProcess(params={
'thresh': 0.3,
'box_thresh': 0.6,
'max_candidates': 1000,
'unclip_ratio': det_db_unclip_ratio
})
all_imgs = []
all_ratios = []
all_results = []
for original_image in predicted_data:
ori_im = original_image.copy()
im, ratio_list = preprocessor(original_image)
res = {'save_path': ''}
if im is None:
res['data'] = []
else:
im = im.copy()
self.input_tensor.copy_from_cpu(im)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
outs_dict = {}
outs_dict['maps'] = outputs[0]
dt_boxes_list = postprocessor(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0]
boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape)
res['data'] = boxes.astype(np.int).tolist()
all_imgs.append(im)
all_ratios.append(ratio_list)
if visualization:
img = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
draw_img = draw_boxes(img, boxes)
draw_img = np.array(draw_img)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
ext = get_image_ext(original_image)
saved_name = 'ndarray_{}{}'.format(time.time(), ext)
cv2.imwrite(os.path.join(output_dir, saved_name), draw_img[:, :, ::-1])
res['save_path'] = os.path.join(output_dir, saved_name)
all_results.append(res)
return all_results
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.detect_text(images=images_decode, **kwargs)
return results
@runnable
def run_cmd(self, argvs):
"""
Run as a command
"""
self.parser = argparse.ArgumentParser(description="Run the %s module." % self.name,
prog='hub run %s' % self.name,
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.detect_text(paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
det_db_unclip_ratio=args.det_db_unclip_ratio,
visualization=args.visualization)
return results
def add_module_config_arg(self):
"""
Add the command config options
"""
self.arg_config_group.add_argument('--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument('--output_dir',
type=str,
default='detection_result',
help="The directory to save output images.")
self.arg_config_group.add_argument('--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
self.arg_config_group.add_argument('--det_db_unclip_ratio',
type=float,
default=1.5,
help="unclip ratio for post processing in DB detection.")
def add_module_input_arg(self):
"""
Add the command input options
"""
self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image")
# -*- coding:utf-8 -*-
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import cv2
import numpy as np
import paddle
import pyclipper
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from shapely.geometry import Polygon
class DBProcessTest(object):
"""
DB pre-process for Test mode
"""
def __init__(self, params):
super(DBProcessTest, self).__init__()
self.resize_type = 0
if 'test_image_shape' in params:
self.image_shape = params['test_image_shape']
self.resize_type = 1
if 'max_side_len' in params:
self.max_side_len = params['max_side_len']
else:
self.max_side_len = 2400
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.max_side_len
h, w, _ = img.shape
# limit the max side
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
# return img, np.array([h, w])
return img, [ratio_h, ratio_w]
def resize_image_type1(self, im):
resize_h, resize_w = self.image_shape
ori_h, ori_w = im.shape[:2] # (h, w, c)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
return im, (ratio_h, ratio_w)
def normalize(self, im):
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
im = im.astype(np.float32, copy=False)
im = im / 255
im[:, :, 0] -= img_mean[0]
im[:, :, 1] -= img_mean[1]
im[:, :, 2] -= img_mean[2]
im[:, :, 0] /= img_std[0]
im[:, :, 1] /= img_std[1]
im[:, :, 2] /= img_std[2]
channel_swap = (2, 0, 1)
im = im.transpose(channel_swap)
return im
def __call__(self, im):
src_h, src_w, _ = im.shape
if self.resize_type == 0:
im, (ratio_h, ratio_w) = self.resize_image_type0(im)
else:
im, (ratio_h, ratio_w) = self.resize_image_type1(im)
im = self.normalize(im)
im = im[np.newaxis, :]
return [im, (src_h, src_w, ratio_h, ratio_w)]
class DBPostProcess(object):
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self, params):
self.thresh = params['thresh']
self.box_thresh = params['box_thresh']
self.max_candidates = params['max_candidates']
self.unclip_ratio = params['unclip_ratio']
self.min_size = 3
self.dilation_kernel = None
self.score_mode = 'fast'
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype(np.int16))
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
def unclip(self, box):
unclip_ratio = self.unclip_ratio
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
mask = segmentation[batch_index]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
boxes_batch.append(tmp_boxes)
return boxes_batch
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
img = image.copy()
draw = ImageDraw.Draw(img)
if scores is None:
scores = [1] * len(boxes)
for (box, score) in zip(boxes, scores):
if score < drop_score:
continue
draw.line([(box[0][0], box[0][1]), (box[1][0], box[1][1])], fill='red')
draw.line([(box[1][0], box[1][1]), (box[2][0], box[2][1])], fill='red')
draw.line([(box[2][0], box[2][1]), (box[3][0], box[3][1])], fill='red')
draw.line([(box[3][0], box[3][1]), (box[0][0], box[0][1])], fill='red')
draw.line([(box[0][0] - 1, box[0][1] + 1), (box[1][0] - 1, box[1][1] + 1)], fill='red')
draw.line([(box[1][0] - 1, box[1][1] + 1), (box[2][0] - 1, box[2][1] + 1)], fill='red')
draw.line([(box[2][0] - 1, box[2][1] + 1), (box[3][0] - 1, box[3][1] + 1)], fill='red')
draw.line([(box[3][0] - 1, box[3][1] + 1), (box[0][0] - 1, box[0][1] + 1)], fill='red')
return img
def get_image_ext(image):
if image.shape[2] == 4:
return ".png"
return ".jpg"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册