“80c68d38ff3c59e12f48b4e4e88c24c89568fc0a”上不存在“paddle/legacy/math/SparseMatrix.h”
未验证 提交 90ef8d10 编写于 作者: S Steffy-zxf 提交者: GitHub

add enable mkldnn for ocr (#888)

上级 7146b59d
...@@ -19,6 +19,16 @@ $ hub run chinese_ocr_db_crnn_mobile --input_path "/PATH/TO/IMAGE" ...@@ -19,6 +19,16 @@ $ hub run chinese_ocr_db_crnn_mobile --input_path "/PATH/TO/IMAGE"
## API ## API
### \_\_init\_\_(text_detector_module=None, enable_mkldnn=False)
构造ChineseOCRDBCRNN对象
**参数**
* text_detector_module(str): 文字检测PaddleHub Module名字,如设置为None,则默认使用[chinese_text_detection_db_mobile Module](https://www.paddlepaddle.org.cn/hubdetail?name=chinese_text_detection_db_mobile&en_category=TextRecognition)。其作用为检测图片当中的文本。
* enable_mkldnn(bool): 是否开启mkldnn加速CPU计算。该参数仅在CPU运行下设置有效。默认为False。
```python ```python
def recognize_text(images=[], def recognize_text(images=[],
paths=[], paths=[],
......
...@@ -25,14 +25,14 @@ from chinese_ocr_db_crnn_mobile.utils import base64_to_cv2, draw_ocr, get_image_ ...@@ -25,14 +25,14 @@ from chinese_ocr_db_crnn_mobile.utils import base64_to_cv2, draw_ocr, get_image_
@moduleinfo( @moduleinfo(
name="chinese_ocr_db_crnn_mobile", name="chinese_ocr_db_crnn_mobile",
version="1.0.3", version="1.0.4",
summary= summary=
"The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ", "The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ",
author="paddle-dev", author="paddle-dev",
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/text_recognition") type="cv/text_recognition")
class ChineseOCRDBCRNN(hub.Module): class ChineseOCRDBCRNN(hub.Module):
def _initialize(self, text_detector_module=None): def _initialize(self, text_detector_module=None, enable_mkldnn=False):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -49,6 +49,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -49,6 +49,7 @@ class ChineseOCRDBCRNN(hub.Module):
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf') self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.pretrained_model_path = os.path.join(self.directory, self.pretrained_model_path = os.path.join(self.directory,
'inference_model') 'inference_model')
self.enable_mkldnn = enable_mkldnn
self._set_config() self._set_config()
def _set_config(self): def _set_config(self):
...@@ -70,6 +71,8 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -70,6 +71,8 @@ class ChineseOCRDBCRNN(hub.Module):
config.enable_use_gpu(8000, 0) config.enable_use_gpu(8000, 0)
else: else:
config.disable_gpu() config.disable_gpu()
if self.enable_mkldnn:
config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
...@@ -92,7 +95,9 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -92,7 +95,9 @@ class ChineseOCRDBCRNN(hub.Module):
""" """
if not self._text_detector_module: if not self._text_detector_module:
self._text_detector_module = hub.Module( self._text_detector_module = hub.Module(
name='chinese_text_detection_db_mobile') name='chinese_text_detection_db_mobile',
enable_mkldnn=self.enable_mkldnn,
version='1.0.2')
return self._text_detector_module return self._text_detector_module
def read_images(self, paths=[]): def read_images(self, paths=[]):
......
...@@ -19,6 +19,15 @@ $ hub run chinese_ocr_db_crnn_server --input_path "/PATH/TO/IMAGE" ...@@ -19,6 +19,15 @@ $ hub run chinese_ocr_db_crnn_server --input_path "/PATH/TO/IMAGE"
## API ## API
### \_\_init\_\_(text_detector_module=None, enable_mkldnn=False)
构造ChineseOCRDBCRNNServer对象
**参数**
* text_detector_module(str): 文字检测PaddleHub Module名字,如设置为None,则默认使用[chinese_text_detection_db_server Module](https://www.paddlepaddle.org.cn/hubdetail?name=chinese_text_detection_db_server&en_category=TextRecognition)。其作用为检测图片当中的文本。
* enable_mkldnn(bool): 是否开启mkldnn加速CPU计算。该参数仅在CPU运行下设置有效。默认为False。
```python ```python
def recognize_text(images=[], def recognize_text(images=[],
paths=[], paths=[],
......
...@@ -25,14 +25,14 @@ from chinese_ocr_db_crnn_server.utils import base64_to_cv2, draw_ocr, get_image_ ...@@ -25,14 +25,14 @@ from chinese_ocr_db_crnn_server.utils import base64_to_cv2, draw_ocr, get_image_
@moduleinfo( @moduleinfo(
name="chinese_ocr_db_crnn_server", name="chinese_ocr_db_crnn_server",
version="1.0.2", version="1.0.3",
summary= summary=
"The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ", "The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ",
author="paddle-dev", author="paddle-dev",
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/text_recognition") type="cv/text_recognition")
class ChineseOCRDBCRNNServer(hub.Module): class ChineseOCRDBCRNNServer(hub.Module):
def _initialize(self, text_detector_module=None): def _initialize(self, text_detector_module=None, enable_mkldnn=False):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -49,6 +49,8 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -49,6 +49,8 @@ class ChineseOCRDBCRNNServer(hub.Module):
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf') self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.pretrained_model_path = os.path.join(self.directory, 'assets', self.pretrained_model_path = os.path.join(self.directory, 'assets',
'ch_rec_r34_vd_crnn') 'ch_rec_r34_vd_crnn')
self.enable_mkldnn = enable_mkldnn
self._set_config() self._set_config()
def _set_config(self): def _set_config(self):
...@@ -70,6 +72,8 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -70,6 +72,8 @@ class ChineseOCRDBCRNNServer(hub.Module):
config.enable_use_gpu(8000, 0) config.enable_use_gpu(8000, 0)
else: else:
config.disable_gpu() config.disable_gpu()
if self.enable_mkldnn:
config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
...@@ -92,7 +96,9 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -92,7 +96,9 @@ class ChineseOCRDBCRNNServer(hub.Module):
""" """
if not self._text_detector_module: if not self._text_detector_module:
self._text_detector_module = hub.Module( self._text_detector_module = hub.Module(
name='chinese_text_detection_db_server') name='chinese_text_detection_db_server',
enable_mkldnn=self.enable_mkldnn,
version='1.0.1')
return self._text_detector_module return self._text_detector_module
def read_images(self, paths=[]): def read_images(self, paths=[]):
...@@ -423,7 +429,7 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -423,7 +429,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
if __name__ == '__main__': if __name__ == '__main__':
ocr = ChineseOCRDBCRNNServer() ocr = ChineseOCRDBCRNNServer(enable_mkldnn=True)
print(ocr.name) print(ocr.name)
image_path = [ image_path = [
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg', '/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg',
......
...@@ -19,6 +19,17 @@ $ hub run chinese_text_detection_db_mobile --input_path "/PATH/TO/IMAGE" ...@@ -19,6 +19,17 @@ $ hub run chinese_text_detection_db_mobile --input_path "/PATH/TO/IMAGE"
## API ## API
## API
### \_\_init\_\_(enable_mkldnn=False)
构造ChineseTextDetectionDB对象
**参数**
* enable_mkldnn(bool): 是否开启mkldnn加速CPU计算。该参数仅在CPU运行下设置有效。默认为False。
```python ```python
def detect_text(paths=[], def detect_text(paths=[],
images=[], images=[],
...@@ -51,7 +62,7 @@ def detect_text(paths=[], ...@@ -51,7 +62,7 @@ def detect_text(paths=[],
import paddlehub as hub import paddlehub as hub
import cv2 import cv2
text_detector = hub.Module(name="chinese_text_detection_db_mobile") text_detector = hub.Module(name="chinese_text_detection_db_mobile", enable_mk)
result = text_detector.detect_text(images=[cv2.imread('/PATH/TO/IMAGE')]) result = text_detector.detect_text(images=[cv2.imread('/PATH/TO/IMAGE')])
# or # or
......
...@@ -29,19 +29,21 @@ def base64_to_cv2(b64str): ...@@ -29,19 +29,21 @@ def base64_to_cv2(b64str):
@moduleinfo( @moduleinfo(
name="chinese_text_detection_db_mobile", name="chinese_text_detection_db_mobile",
version="1.0.1", version="1.0.2",
summary= summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.", "The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev", author="paddle-dev",
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/text_recognition") type="cv/text_recognition")
class ChineseTextDetectionDB(hub.Module): class ChineseTextDetectionDB(hub.Module):
def _initialize(self): def _initialize(self, enable_mkldnn=False):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
self.pretrained_model_path = os.path.join(self.directory, self.pretrained_model_path = os.path.join(self.directory,
'inference_model') 'inference_model')
self.enable_mkldnn = enable_mkldnn
self._set_config() self._set_config()
def check_requirements(self): def check_requirements(self):
...@@ -71,6 +73,8 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -71,6 +73,8 @@ class ChineseTextDetectionDB(hub.Module):
config.enable_use_gpu(8000, 0) config.enable_use_gpu(8000, 0)
else: else:
config.disable_gpu() config.disable_gpu()
if self.enable_mkldnn:
config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
......
...@@ -19,6 +19,14 @@ $ hub run chinese_text_detection_db_server --input_path "/PATH/TO/IMAGE" ...@@ -19,6 +19,14 @@ $ hub run chinese_text_detection_db_server --input_path "/PATH/TO/IMAGE"
## API ## API
### \_\_init\_\_(enable_mkldnn=False)
构造ChineseTextDetectionDBServer对象
**参数**
* enable_mkldnn(bool): 是否开启mkldnn加速CPU计算。该参数仅在CPU运行下设置有效。默认为False。
```python ```python
def detect_text(paths=[], def detect_text(paths=[],
images=[], images=[],
......
...@@ -29,19 +29,21 @@ def base64_to_cv2(b64str): ...@@ -29,19 +29,21 @@ def base64_to_cv2(b64str):
@moduleinfo( @moduleinfo(
name="chinese_text_detection_db_server", name="chinese_text_detection_db_server",
version="1.0.0", version="1.0.1",
summary= summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.", "The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev", author="paddle-dev",
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/text_recognition") type="cv/text_recognition")
class ChineseTextDetectionDBServer(hub.Module): class ChineseTextDetectionDBServer(hub.Module):
def _initialize(self): def _initialize(self, enable_mkldnn=False):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
self.pretrained_model_path = os.path.join(self.directory, self.pretrained_model_path = os.path.join(self.directory,
'ch_det_r50_vd_db') 'ch_det_r50_vd_db')
self.enable_mkldnn = enable_mkldnn
self._set_config() self._set_config()
def check_requirements(self): def check_requirements(self):
...@@ -71,6 +73,8 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -71,6 +73,8 @@ class ChineseTextDetectionDBServer(hub.Module):
config.enable_use_gpu(8000, 0) config.enable_use_gpu(8000, 0)
else: else:
config.disable_gpu() config.disable_gpu()
if self.enable_mkldnn:
config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册