Commit 227f5f3a authored by WenmuZhou

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleOCR into master

...@@ -189,7 +189,7 @@ PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训 ...@@ -189,7 +189,7 @@ PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训
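Scan the QR code below and complete the questionnaire to get the group QR code and OCR model-training tips Scan the QR code below and complete the questionnaire to get the group QR code and OCR model-training tips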
<div align="center"> <div align="center">
<img src="./doc/joinus.jpg" width = "200" height = "200" /> <img src="./doc/joinus.PNG" width = "200" height = "200" />
</div> </div>
<a name="许可证书"></a> <a name="许可证书"></a>
......
...@@ -56,7 +56,6 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr ...@@ -56,7 +56,6 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr
- Algorithm introduction - Algorithm introduction
- [Text Detection Algorithm](#TEXTDETECTIONALGORITHM) - [Text Detection Algorithm](#TEXTDETECTIONALGORITHM)
- [Text Recognition Algorithm](#TEXTRECOGNITIONALGORITHM) - [Text Recognition Algorithm](#TEXTRECOGNITIONALGORITHM)
- [END-TO-END OCR Algorithm](#ENDENDOCRALGORITHM)
- Model training/evaluation - Model training/evaluation
- [Text Detection](./doc/doc_en/detection_en.md) - [Text Detection](./doc/doc_en/detection_en.md)
- [Text Recognition](./doc/doc_en/recognition_en.md) - [Text Recognition](./doc/doc_en/recognition_en.md)
...@@ -158,10 +157,6 @@ We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/ ...@@ -158,10 +157,6 @@ We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md) Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md)
<a name="ENDENDOCRALGORITHM"></a>
## END-TO-END OCR Algorithm
- [ ] [End2End-PSL](https://arxiv.org/abs/1909.07808)(Baidu Self-Research, coming soon)
## Visualization ## Visualization
<a name="UCOCRVIS"></a> <a name="UCOCRVIS"></a>
...@@ -211,7 +206,7 @@ Please refer to the document for training guide and use of PaddleOCR text recogn ...@@ -211,7 +206,7 @@ Please refer to the document for training guide and use of PaddleOCR text recogn
Scan the QR code below with WeChat and complete the questionnaire to join the official technical exchange group. Scan the QR code below with WeChat and complete the questionnaire to join the official technical exchange group.
<div align="center"> <div align="center">
<img src="./doc/joinus.jpg" width = "200" height = "200" /> <img src="./doc/joinus.PNG" width = "200" height = "200" />
</div> </div>
<a name="LICENSE"></a> <a name="LICENSE"></a>
......
...@@ -26,6 +26,8 @@ void DBDetector::LoadModel(const std::string &model_dir) { ...@@ -26,6 +26,8 @@ void DBDetector::LoadModel(const std::string &model_dir) {
config.DisableGpu(); config.DisableGpu();
if (this->use_mkldnn_) { if (this->use_mkldnn_) {
config.EnableMKLDNN(); config.EnableMKLDNN();
// cache 10 different shapes for mkldnn to avoid memory leak
config.SetMkldnnCacheCapacity(10);
} }
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
} }
......
...@@ -126,6 +126,8 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { ...@@ -126,6 +126,8 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
config.DisableGpu(); config.DisableGpu();
if (this->use_mkldnn_) { if (this->use_mkldnn_) {
config.EnableMKLDNN(); config.EnableMKLDNN();
// cache 10 different shapes for mkldnn to avoid memory leak
config.SetMkldnnCacheCapacity(10);
} }
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
} }
......
...@@ -3,7 +3,7 @@ use_gpu 0 ...@@ -3,7 +3,7 @@ use_gpu 0
gpu_id 0 gpu_id 0
gpu_mem 4000 gpu_mem 4000
cpu_math_library_num_threads 10 cpu_math_library_num_threads 10
use_mkldnn 0 use_mkldnn 1
use_zero_copy_run 1 use_zero_copy_run 1
# det config # det config
......
...@@ -20,7 +20,7 @@ git clone https://github.com/PaddlePaddle/PaddleOCR.git ...@@ -20,7 +20,7 @@ git clone https://github.com/PaddlePaddle/PaddleOCR.git
``` ```
b. Go to the Dockerfile directory (note: the CPU and GPU versions differ; the following uses CPU as an example, and the GPU version needs the keyword replaced) b. Go to the Dockerfile directory (note: the CPU and GPU versions differ; the following uses CPU as an example, and the GPU version needs the keyword replaced)
``` ```
cd docker/cpu cd deploy/docker/cpu
``` ```
c. Build image c. Build image
``` ```
......
...@@ -20,7 +20,7 @@ git clone https://github.com/PaddlePaddle/PaddleOCR.git ...@@ -20,7 +20,7 @@ git clone https://github.com/PaddlePaddle/PaddleOCR.git
``` ```
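b. Go to the Dockerfile directory (note: the CPU and GPU versions differ; the following uses CPU as an example, and the GPU version only needs the keyword replaced) b. Go to the Dockerfile directory (note: the CPU and GPU versions differ; the following uses CPU as an example, and the GPU version only needs the keyword replaced)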
``` ```
cd docker/cpu cd deploy/docker/cpu
``` ```
c. Build the image c. Build the image
``` ```
......
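> Install PaddleSlim 1.2.0 or later before running this example
# Model Quantization Compression Tutorial
## Overview
This example uses the [quantization API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/) provided by PaddleSlim to compress OCR models.
Before reading this example, we recommend that you first understand the following:
- [The standard training method for OCR models](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md)
- [PaddleSlim documentation](https://paddlepaddle.github.io/PaddleSlim/)
## Install PaddleSlim
Install PaddleSlim by following the steps in the [PaddleSlim documentation](https://paddlepaddle.github.io/PaddleSlim/).
## Quantization Training
Go to the PaddleOCR root directory and quantize the model with the following command: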
```bash
python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=det_mv3_db/best_accuracy Global.save_model_dir=./output/quant_model
```
## Evaluate and Export
After obtaining the model saved by quantization training, you can export it as an inference_model for prediction and deployment:
```bash
python deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_model_dir=./output/quant_model
```
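As a rough numerical illustration of the `channel_wise_abs_max` weight setting with 8 bits used by the scripts in this commit, the sketch below scales each output channel by its own absolute maximum, rounds to a signed integer, and maps the value back to a float. It is a simplified, standard-library-only assumption about the scheme, not PaddleSlim's implementation; `fake_quant_channel` and `fake_quant_channel_wise` are hypothetical helper names.

```python
def fake_quant_channel(weights, bits=8):
    """Quantize one channel to signed integers, then map back to floats."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    scale = max(abs(w) for w in weights)  # per-channel abs-max
    if scale == 0:
        return list(weights)
    quantized = [round(w / scale * qmax) for w in weights]
    return [q * scale / qmax for q in quantized]


def fake_quant_channel_wise(weight_matrix, bits=8):
    """Apply the per-channel scheme to every row (one row per channel)."""
    return [fake_quant_channel(channel, bits) for channel in weight_matrix]


weights = [[0.5, -1.0, 0.25], [0.01, 0.02, -0.04]]
restored = fake_quant_channel_wise(weights)
```

Because every channel has its own scale, the small-magnitude second channel keeps roughly the same relative precision as the first, which is the usual motivation for per-channel rather than per-tensor scales.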
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
sys.path.append(
    os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))


def set_paddle_flags(**kwargs):
    for key, value in kwargs.items():
        if os.environ.get(key, None) is None:
            os.environ[key] = str(value)


# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)

import program
from paddle import fluid
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.save_load import init_model, load_params
from ppocr.utils.character import CharacterOps
from ppocr.utils.utility import create_module
from ppocr.data.reader_main import reader_main
from paddleslim.quant import quant_aware, convert
from paddle.fluid.layer_helper import LayerHelper
from eval_utils.eval_det_utils import eval_det_run
from eval_utils.eval_rec_utils import eval_rec_run


def main():
    # 1. quantization configs
    quant_config = {
        # weight quantize type, default is 'channel_wise_abs_max'
        'weight_quantize_type': 'channel_wise_abs_max',
        # activation quantize type, default is 'moving_average_abs_max'
        'activation_quantize_type': 'moving_average_abs_max',
        # weight quantize bit num, default is 8
        'weight_bits': 8,
        # activation quantize bit num, default is 8
        'activation_bits': 8,
        # ops of name_scope in not_quant_pattern list, will not be quantized
        'not_quant_pattern': ['skip_quant'],
        # ops of type in quantize_op_types, will be quantized
        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
        'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
        'window_size': 10000,
        # The decay coefficient of moving average, default is 0.9
        'moving_rate': 0.9,
    }

    startup_prog, eval_program, place, config, alg_type = program.preprocess()
    feeded_var_names, target_vars, fetches_var_name = program.build_export(
        config, eval_program, startup_prog)
    eval_program = eval_program.clone(for_test=True)
    exe = fluid.Executor(place)
    exe.run(startup_prog)
    eval_program = quant_aware(
        eval_program, place, quant_config, scope=None, for_test=True)
    init_model(config, eval_program, exe)

    # 2. Convert the program before save inference program
    # The dtype of eval_program's weights is float32, but in int8 range.
    eval_program = convert(eval_program, place, quant_config, scope=None)
    eval_fetch_name_list = fetches_var_name
    eval_fetch_varname_list = [v.name for v in target_vars]
    eval_reader = reader_main(config=config, mode="eval")
    quant_info_dict = {'program':eval_program,\
        'reader':eval_reader,\
        'fetch_name_list':eval_fetch_name_list,\
        'fetch_varname_list':eval_fetch_varname_list}
    if alg_type == 'det':
        final_metrics = eval_det_run(exe, config, quant_info_dict, "eval")
    else:
        final_metrics = eval_rec_run(exe, config, quant_info_dict, "eval")
    print(final_metrics)

    # 3. Save inference model
    model_path = "./quant_model"
    if not os.path.isdir(model_path):
        os.makedirs(model_path)
    fluid.io.save_inference_model(
        dirname=model_path,
        feeded_var_names=feeded_var_names,
        target_vars=target_vars,
        executor=exe,
        main_program=eval_program,
        model_filename=model_path + '/model',
        params_filename=model_path + '/params')
    print("model saved as {}".format(model_path))


if __name__ == '__main__':
    main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
sys.path.append(
    os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))


def set_paddle_flags(**kwargs):
    for key, value in kwargs.items():
        if os.environ.get(key, None) is None:
            os.environ[key] = str(value)


# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)

import tools.program as program
from paddle import fluid
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.data.reader_main import reader_main
from ppocr.utils.save_load import init_model
from paddle.fluid.contrib.model_stat import summary

# quant dependencies
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert
from paddle.fluid.layer_helper import LayerHelper


def pact(x):
    """
    Process a variable using the PACT method defined here.
    Args:
        x(Tensor): Paddle Tensor to be preprocessed before quantization
    Returns:
        The processed Tensor x.
    """
    helper = LayerHelper("pact", **locals())
    dtype = 'float32'
    init_thres = 20
    u_param_attr = fluid.ParamAttr(
        name=x.name + '_pact',
        initializer=fluid.initializer.ConstantInitializer(value=init_thres),
        regularizer=fluid.regularizer.L2Decay(0.0001),
        learning_rate=1)
    u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype)
    x = fluid.layers.elementwise_sub(
        x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
    x = fluid.layers.elementwise_add(
        x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
    return x


def get_optimizer():
    """
    Return the optimizer passed to quant_aware as optimizer_func
    """
    return fluid.optimizer.AdamOptimizer(0.001)


def main():
    train_build_outputs = program.build(
        config, train_program, startup_program, mode='train')
    train_loader = train_build_outputs[0]
    train_fetch_name_list = train_build_outputs[1]
    train_fetch_varname_list = train_build_outputs[2]
    train_opt_loss_name = train_build_outputs[3]
    model_average = train_build_outputs[-1]

    eval_program = fluid.Program()
    eval_build_outputs = program.build(
        config, eval_program, startup_program, mode='eval')
    eval_fetch_name_list = eval_build_outputs[1]
    eval_fetch_varname_list = eval_build_outputs[2]
    eval_program = eval_program.clone(for_test=True)

    train_reader = reader_main(config=config, mode="train")
    train_loader.set_sample_list_generator(train_reader, places=place)
    eval_reader = reader_main(config=config, mode="eval")

    exe = fluid.Executor(place)
    exe.run(startup_program)

    # 1. quantization configs
    quant_config = {
        # weight quantize type, default is 'channel_wise_abs_max'
        'weight_quantize_type': 'channel_wise_abs_max',
        # activation quantize type, default is 'moving_average_abs_max'
        'activation_quantize_type': 'moving_average_abs_max',
        # weight quantize bit num, default is 8
        'weight_bits': 8,
        # activation quantize bit num, default is 8
        'activation_bits': 8,
        # ops of name_scope in not_quant_pattern list, will not be quantized
        'not_quant_pattern': ['skip_quant'],
        # ops of type in quantize_op_types, will be quantized
        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
        'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
        'window_size': 10000,
        # The decay coefficient of moving average, default is 0.9
        'moving_rate': 0.9,
    }

    # 2. quantization transform programs (training aware)
    # Make some quantization transforms in the graph before training and testing.
    # According to the weight and activation quantization type, fake quantize
    # and fake dequantize operators are added to the graph.
    act_preprocess_func = pact
    optimizer_func = get_optimizer
    executor = exe
    eval_program = quant_aware(
        eval_program,
        place,
        quant_config,
        scope=None,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=executor,
        for_test=True)
    quant_train_program = quant_aware(
        train_program,
        place,
        quant_config,
        scope=None,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=executor,
        for_test=False,
        return_program=True)

    # compile program for multi-devices
    train_compile_program = program.create_multi_devices_program(
        quant_train_program, train_opt_loss_name, for_quant=True)
    init_model(config, quant_train_program, exe)

    train_info_dict = {'compile_program':train_compile_program,\
        'train_program':quant_train_program,\
        'reader':train_loader,\
        'fetch_name_list':train_fetch_name_list,\
        'fetch_varname_list':train_fetch_varname_list,\
        'model_average': model_average}
    eval_info_dict = {'program':eval_program,\
        'reader':eval_reader,\
        'fetch_name_list':eval_fetch_name_list,\
        'fetch_varname_list':eval_fetch_varname_list}

    if train_alg_type == 'det':
        program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
    else:
        program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)


if __name__ == '__main__':
    startup_program, train_program, place, config, train_alg_type = program.preprocess()
    main()
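The two elementwise steps in `pact` appear to amount to clipping the activation to the range `[-u, u]`, where `u` is the learnable threshold (initialized to 20 in the script). The scalar sketch below checks that reading with plain Python floats; `relu` and `pact_clip` are illustrative names and are not part of the commit.

```python
def relu(value):
    """Scalar ReLU."""
    return max(value, 0.0)


def pact_clip(x, u):
    """Mirror the two steps of `pact` on a single float."""
    x = x - relu(x - u)    # subtracts the excess above +u, so x <= u
    x = x + relu(-u - x)   # adds back the deficit below -u, so x >= -u
    return x


samples = [-30.0, -20.0, 3.5, 20.0, 25.0]
clipped = [pact_clip(value, 20.0) for value in samples]
```

Writing the clip with `relu` keeps the operation differentiable with respect to the threshold, which is presumably why the script builds it from elementwise ops instead of a fixed clip.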
...@@ -140,7 +140,7 @@ PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入 ...@@ -140,7 +140,7 @@ PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入
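During training, each perturbation is selected with 50% probability; for the implementation, see [img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) During training, each perturbation is selected with 50% probability; for the implementation, see [img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
*Due to OpenCV compatibility issues, perturbation is currently supported only on GPU* *Due to OpenCV compatibility issues, perturbation is currently supported only on Linux*
- Training - Training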
......
...@@ -61,6 +61,14 @@ hub install deploy\hubserving\ocr_rec\ ...@@ -61,6 +61,14 @@ hub install deploy\hubserving\ocr_rec\
hub install deploy\hubserving\ocr_system\ hub install deploy\hubserving\ocr_system\
``` ```
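#### Install the models
Before installing the service module, place the trained models in the corresponding folders. The defaults are:
./inference/ch_det_mv3_db/
./inference/ch_rec_mv3_crnn/
Both models can be downloaded from https://github.com/PaddlePaddle/PaddleOCR
To use your own models, edit ./deploy/hubserving/ocr_system/params.py
### 3. Start the service ### 3. Start the service
#### Method 1. Start from the command line (CPU only) #### Method 1. Start from the command line (CPU only)
**Start command:** **Start command:**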
......
English | [简体中文](README_cn.md)
## Introduction
Many users want to package the PaddleOCR service into a Docker image so that it can be quickly released and used in a Docker or k8s environment.
This page provides some standardized code to achieve this goal. You can quickly publish the PaddleOCR project as a callable RESTful API service through the following steps. (At present, deployment based on the HubServing mode is implemented first; the author plans to add deployment based on the PaddleServing mode in the future.)
## 1. Prerequisites
You need to install the following basic components first:
a. Docker
b. Graphics driver and CUDA 10.0+ (GPU)
c. NVIDIA Container Toolkit (GPU; can be skipped with Docker 19.03+)
d. cuDNN 7.6+ (GPU)
## 2. Build Image
a. Download the PaddleOCR source code
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b. Go to the Dockerfile directory (note: the CPU and GPU versions differ; the following uses CPU as an example, and the GPU version needs the keyword replaced)
```
cd docker/cpu
```
c. Build image
```
docker build -t paddleocr:cpu .
```
## 3. Start container
a. CPU version
```
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
```
b. GPU version (based on NVIDIA Container Toolkit)
```
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
```
c. GPU version (Docker 19.03+)
```
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
```
d. Check the service status (startup is complete when the log shows: Successfully installed ocr_system && Running on http://0.0.0.0:8866/)
```
docker logs -f paddle_ocr
```
## 4. Test
a. Compute the Base64 encoding of the picture to be recognized (for a quick test, you can use a free online tool such as https://freeonlinetools24.com/base64-image/)
b. Post a service request (a sample request is in sample_request.txt)
```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system
```
c. Get the response (if the call is successful, the following result is returned)
```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
```
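The request and response shapes above can also be handled from a script instead of an online tool. The sketch below builds the JSON body for `/predict/ocr_system` from raw image bytes, strips a `data:image/...;base64,` prefix if one is present, and pulls the recognized text out of a response. It uses only the standard library; the helper names are hypothetical, and the field names are taken from the sample response shown above.

```python
import base64
import json


def strip_data_uri_prefix(text):
    """Drop a leading 'data:image/...;base64,' prefix if present."""
    marker = ";base64,"
    if text.startswith("data:") and marker in text:
        return text.split(marker, 1)[1]
    return text


def build_ocr_payload(image_bytes):
    """Return the JSON request body for one image."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return json.dumps({"images": [encoded]})


def extract_texts(response_body):
    """Return (text, confidence) pairs from a response body."""
    result = json.loads(response_body)
    return [(item["text"], item["confidence"])
            for image_result in result["results"]
            for item in image_result]


payload = build_ocr_payload(b"fake image bytes")
```

The resulting `payload` string can be sent with any HTTP client using the `Content-Type: application/json` header, as in the curl command above.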
...@@ -121,24 +121,22 @@ def RandomCropData(data, size): ...@@ -121,24 +121,22 @@ def RandomCropData(data, size):
all_care_polys = [ all_care_polys = [
text_polys[i] for i, tag in enumerate(ignore_tags) if not tag text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
] ]
# compute the crop region
crop_x, crop_y, crop_w, crop_h = crop_area(im, all_care_polys, crop_x, crop_y, crop_w, crop_h = crop_area(im, all_care_polys,
min_crop_side_ratio, max_tries) min_crop_side_ratio, max_tries)
# crop the image and pad to keep the aspect ratio dh, dw = size
scale_w = size[0] / crop_w scale_w = dw / crop_w
scale_h = size[1] / crop_h scale_h = dh / crop_h
scale = min(scale_w, scale_h) scale = min(scale_w, scale_h)
h = int(crop_h * scale) h = int(crop_h * scale)
w = int(crop_w * scale) w = int(crop_w * scale)
if keep_ratio: if keep_ratio:
padimg = np.zeros((size[1], size[0], im.shape[2]), im.dtype) padimg = np.zeros((dh, dw, im.shape[2]), im.dtype)
padimg[:h, :w] = cv2.resize( padimg[:h, :w] = cv2.resize(
im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
img = padimg img = padimg
else: else:
img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
tuple(size)) (dw, dh))
# crop the text boxes
text_polys_crop = [] text_polys_crop = []
ignore_tags_crop = [] ignore_tags_crop = []
texts_crop = [] texts_crop = []
......
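The `RandomCropData` hunk above switches from indexing `size[0]` and `size[1]` to unpacking `size` as `(dh, dw)`, that is, height first. The sketch below reproduces only the scale arithmetic for the keep-ratio path, assuming that `(height, width)` convention; `keep_ratio_size` is a hypothetical helper name, and the numbers are made up for illustration.

```python
def keep_ratio_size(crop_w, crop_h, size):
    """Return the (w, h) of the resized crop that fits inside `size`."""
    dh, dw = size  # assumed order: (height, width)
    scale = min(dw / crop_w, dh / crop_h)
    return int(crop_w * scale), int(crop_h * scale)


# A 100x50 crop placed into a canvas 300 high and 200 wide.
resized_w, resized_h = keep_ratio_size(100, 50, (300, 200))
```

With a non-square target, reading the two entries in the opposite order would give a scale of 3 here and a 300-pixel-wide result that no longer fits the 200-pixel-wide canvas, so the order of `size` matters.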
...@@ -67,6 +67,7 @@ class DetModel(object): ...@@ -67,6 +67,7 @@ class DetModel(object):
image = fluid.layers.data( image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32') name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
if mode == "train": if mode == "train":
if self.algorithm == "EAST": if self.algorithm == "EAST":
h, w = int(image_shape[1] // 4), int(image_shape[2] // 4) h, w = int(image_shape[1] // 4), int(image_shape[2] // 4)
...@@ -108,7 +109,10 @@ class DetModel(object): ...@@ -108,7 +109,10 @@ class DetModel(object):
name='tvo', shape=[9, 128, 128], dtype='float32') name='tvo', shape=[9, 128, 128], dtype='float32')
input_tco = fluid.layers.data( input_tco = fluid.layers.data(
name='tco', shape=[3, 128, 128], dtype='float32') name='tco', shape=[3, 128, 128], dtype='float32')
feed_list = [image, input_score, input_border, input_mask, input_tvo, input_tco] feed_list = [
image, input_score, input_border, input_mask, input_tvo,
input_tco
]
labels = {'input_score': input_score,\ labels = {'input_score': input_score,\
'input_border': input_border,\ 'input_border': input_border,\
'input_mask': input_mask,\ 'input_mask': input_mask,\
......
...@@ -68,6 +68,7 @@ class RecModel(object): ...@@ -68,6 +68,7 @@ class RecModel(object):
image_shape.insert(0, -1) image_shape.insert(0, -1)
if mode == "train": if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32') image = fluid.data(name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
if self.loss_type == "attention": if self.loss_type == "attention":
label_in = fluid.data( label_in = fluid.data(
name='label_in', name='label_in',
...@@ -136,7 +137,7 @@ class RecModel(object): ...@@ -136,7 +137,7 @@ class RecModel(object):
else: else:
labels = None labels = None
loader = None loader = None
if self.char_type == "ch" and self.infer_img: if self.char_type == "ch" and self.infer_img and self.loss_type != "srn":
image_shape[-1] = -1 image_shape[-1] = -1
if self.tps != None: if self.tps != None:
logger.info( logger.info(
...@@ -146,6 +147,7 @@ class RecModel(object): ...@@ -146,6 +147,7 @@ class RecModel(object):
) )
image_shape = deepcopy(self.image_shape) image_shape = deepcopy(self.image_shape)
image = fluid.data(name='image', shape=image_shape, dtype='float32') image = fluid.data(name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
if self.loss_type == "srn": if self.loss_type == "srn":
encoder_word_pos = fluid.data( encoder_word_pos = fluid.data(
name="encoder_word_pos", name="encoder_word_pos",
...@@ -172,16 +174,13 @@ class RecModel(object): ...@@ -172,16 +174,13 @@ class RecModel(object):
self.max_text_length self.max_text_length
], ],
dtype="float32") dtype="float32")
feed_list = [
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
labels = { labels = {
'encoder_word_pos': encoder_word_pos, 'encoder_word_pos': encoder_word_pos,
'gsrm_word_pos': gsrm_word_pos, 'gsrm_word_pos': gsrm_word_pos,
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
} }
return image, labels, loader return image, labels, loader
def __call__(self, mode): def __call__(self, mode):
...@@ -218,8 +217,13 @@ class RecModel(object): ...@@ -218,8 +217,13 @@ class RecModel(object):
if self.loss_type == "ctc": if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict) predict = fluid.layers.softmax(predict)
if self.loss_type == "srn": if self.loss_type == "srn":
raise Exception( return [
"Warning! SRN does not support export model currently") image, labels, {
'decoded_out': decoded_out,
'predicts': predict
}
]
return [image, {'decoded_out': decoded_out, 'predicts': predict}] return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else: else:
predict = predicts['predict'] predict = predicts['predict']
......
...@@ -35,6 +35,7 @@ class CTCPredict(object): ...@@ -35,6 +35,7 @@ class CTCPredict(object):
self.fc_decay = params.get("fc_decay", 0.0004) self.fc_decay = params.get("fc_decay", 0.0004)
def __call__(self, inputs, labels=None, mode=None): def __call__(self, inputs, labels=None, mode=None):
with fluid.scope_guard("skip_quant"):
encoder_features = self.encoder(inputs) encoder_features = self.encoder(inputs)
if self.encoder_type != "reshape": if self.encoder_type != "reshape":
encoder_features = fluid.layers.concat(encoder_features, axis=1) encoder_features = fluid.layers.concat(encoder_features, axis=1)
......
...@@ -90,15 +90,3 @@ def check_and_read_gif(img_path): ...@@ -90,15 +90,3 @@ def check_and_read_gif(img_path):
return imgvalue, True return imgvalue, True
return None, False return None, False
def create_multi_devices_program(program, loss_var_name):
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = True
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1
compile_program = fluid.CompiledProgram(program).with_data_parallel(
loss_name=loss_var_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
return compile_program
...@@ -40,6 +40,7 @@ class TextRecognizer(object): ...@@ -40,6 +40,7 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm self.rec_algorithm = args.rec_algorithm
self.text_len = args.max_text_length
self.use_zero_copy_run = args.use_zero_copy_run self.use_zero_copy_run = args.use_zero_copy_run
char_ops_params = { char_ops_params = {
"character_type": args.rec_char_type, "character_type": args.rec_char_type,
...@@ -47,12 +48,15 @@ class TextRecognizer(object): ...@@ -47,12 +48,15 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char, "use_space_char": args.use_space_char,
"max_text_length": args.max_text_length "max_text_length": args.max_text_length
} }
if self.rec_algorithm != "RARE": if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
char_ops_params['loss_type'] = 'ctc' char_ops_params['loss_type'] = 'ctc'
self.loss_type = 'ctc' self.loss_type = 'ctc'
else: elif self.rec_algorithm == "RARE":
char_ops_params['loss_type'] = 'attention' char_ops_params['loss_type'] = 'attention'
self.loss_type = 'attention' self.loss_type = 'attention'
elif self.rec_algorithm == "SRN":
char_ops_params['loss_type'] = 'srn'
self.loss_type = 'srn'
self.char_ops = CharacterOps(char_ops_params) self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
...@@ -75,6 +79,83 @@ class TextRecognizer(object): ...@@ -75,6 +79,83 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image padding_im[:, :, 0:resized_w] = resized_image
return padding_im return padding_im
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length,
char_num):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self,
img,
image_shape,
num_heads,
max_text_length,
char_ops=None):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
char_num = char_ops.get_char_num()
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def __call__(self, img_list): def __call__(self, img_list):
img_num = len(img_list) img_num = len(img_list)
# Calculate the aspect ratio of all text bars # Calculate the aspect ratio of all text bars
...@@ -84,7 +165,7 @@ class TextRecognizer(object): ...@@ -84,7 +165,7 @@ class TextRecognizer(object):
# Sorting can speed up the recognition process # Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list)) indices = np.argsort(np.array(width_list))
# rec_res = [] #rec_res = []
rec_res = [['', 0.0]] * img_num rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num batch_num = self.rec_batch_num
predict_time = 0 predict_time = 0
...@@ -98,13 +179,55 @@ class TextRecognizer(object): ...@@ -98,13 +179,55 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) if self.loss_type != "srn":
norm_img = self.resize_norm_img(img_list[indices[ino]], norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio) max_wh_ratio)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch) else:
norm_img = self.process_image_srn(img_list[indices[ino]],
self.rec_image_shape, 8,
25, self.char_ops)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
if self.loss_type == "srn":
starttime = time.time()
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(
gsrm_slf_attn_bias2_list)
starttime = time.time()
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
encoder_word_pos_list = fluid.core.PaddleTensor(
encoder_word_pos_list)
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch, encoder_word_pos_list,
gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list,
gsrm_word_pos_list
]
self.predictor.run(inputs)
else:
starttime = time.time() starttime = time.time()
if self.use_zero_copy_run: if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
...@@ -136,6 +259,26 @@ class TextRecognizer(object): ...@@ -136,6 +259,26 @@ class TextRecognizer(object):
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
# rec_res.append([preds_text, score]) # rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score] rec_res[indices[beg_img_no + rno]] = [preds_text, score]
elif self.loss_type == 'srn':
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
probs = self.output_tensors[1].copy_to_cpu()
char_num = self.char_ops.get_char_num()
preds = rec_idx_batch.reshape(-1)
elapse = time.time() - starttime
predict_time += elapse
total_preds = preds.copy()
for ino in range(int(len(rec_idx_batch) / self.text_len)):
preds = total_preds[ino * self.text_len:(ino + 1) *
self.text_len]
ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != int(char_num - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
preds = preds[:valid_ind[-1] + 1]
preds_text = self.char_ops.decode(preds)
rec_res[indices[beg_img_no + ino]] = [preds_text, score]
else: else:
rec_idx_batch = self.output_tensors[0].copy_to_cpu() rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_batch = self.output_tensors[1].copy_to_cpu() predict_batch = self.output_tensors[1].copy_to_cpu()
...@@ -170,6 +313,7 @@ def main(args): ...@@ -170,6 +313,7 @@ def main(args):
continue continue
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
try: try:
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
except Exception as e: except Exception as e:
......
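Review note: the hunk above sorts the images by aspect ratio, runs them in batches, and writes each result back through `indices` so the output keeps the caller's order. The following is a minimal pure-Python sketch of that bookkeeping only. `batch_by_width`, `recognize_all`, and the `recognize_batch` callback are hypothetical names used for illustration and are not part of PaddleOCR; `sorted(...)` stands in for `np.argsort`.

```python
def batch_by_width(width_list, batch_num):
    """Yield batches of original indices, ordered by ascending aspect ratio."""
    # Same idea as np.argsort(width_list): positions ordered by their value.
    indices = sorted(range(len(width_list)), key=lambda i: width_list[i])
    for beg in range(0, len(indices), batch_num):
        yield indices[beg:beg + batch_num]


def recognize_all(width_list, batch_num, recognize_batch):
    """Run a (hypothetical) batch recognizer and restore the input order."""
    # Pre-fill one independent placeholder per image so every slot exists.
    rec_res = [["", 0.0] for _ in width_list]
    for batch in batch_by_width(width_list, batch_num):
        results = recognize_batch(batch)  # assumed: one result per index
        for original_index, result in zip(batch, results):
            rec_res[original_index] = result
    return rec_res


# Stand-in recognizer that only echoes the index it was given.
def fake_recognizer(batch):
    return [["img%d" % i, 1.0] for i in batch]


ratios = [3.0, 1.5, 8.0, 2.0, 1.0]
batches = list(batch_by_width(ratios, 2))
restored = recognize_all(ratios, 2, fake_recognizer)
```

Grouping images of similar width keeps the padding inside each batch small, and writing through the original index means callers never see the sorted order.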
...@@ -122,7 +122,6 @@ def main(args): ...@@ -122,7 +122,6 @@ def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_sys = TextSystem(args) text_sys = TextSystem(args)
is_visualize = True is_visualize = True
tackle_img_num = 0
for image_file in image_file_list: for image_file in image_file_list:
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -131,9 +130,6 @@ def main(args): ...@@ -131,9 +130,6 @@ def main(args):
logger.info("error in loading image:{}".format(image_file)) logger.info("error in loading image:{}".format(image_file))
continue continue
starttime = time.time() starttime = time.time()
tackle_img_num += 1
if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
text_sys = TextSystem(args)
dt_boxes, rec_res = text_sys(img) dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime elapse = time.time() - starttime
print("Predict time of %s: %.3fs" % (image_file, elapse)) print("Predict time of %s: %.3fs" % (image_file, elapse))
...@@ -153,11 +149,7 @@ def main(args): ...@@ -153,11 +149,7 @@ def main(args):
scores = [rec_res[i][1] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr( draw_img = draw_ocr(
image, image, boxes, txts, scores, drop_score=drop_score)
boxes,
txts,
scores,
drop_score=drop_score)
draw_img_save = "./inference_results/" draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save): if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save) os.makedirs(draw_img_save)
......
...@@ -101,6 +101,8 @@ def create_predictor(args, mode): ...@@ -101,6 +101,8 @@ def create_predictor(args, mode):
config.disable_gpu() config.disable_gpu()
config.set_cpu_math_library_num_threads(6) config.set_cpu_math_library_num_threads(6)
if args.enable_mkldnn: if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn() config.enable_mkldnn()
#config.enable_memory_optim() #config.enable_memory_optim()
...@@ -114,7 +116,8 @@ def create_predictor(args, mode): ...@@ -114,7 +116,8 @@ def create_predictor(args, mode):
predictor = create_paddle_predictor(config) predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0]) for name in input_names:
input_tensor = predictor.get_input_tensor(name)
output_names = predictor.get_output_names() output_names = predictor.get_output_names()
output_tensors = [] output_tensors = []
for output_name in output_names: for output_name in output_names:
......
...@@ -145,7 +145,7 @@ def main(): ...@@ -145,7 +145,7 @@ def main():
preds = preds.reshape(-1) preds = preds.reshape(-1)
probs = np.array(predict[1]) probs = np.array(predict[1])
ind = np.argmax(probs, axis=1) ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != int(char_num-1))[0] valid_ind = np.where(preds != int(char_num - 1))[0]
if len(valid_ind) == 0: if len(valid_ind) == 0:
continue continue
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
......
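Review note: both this hunk and the `srn` branch in the recognizer decode fixed-length predictions by dropping the padding index (`char_num - 1`) and averaging the top probability over the remaining positions. The sketch below shows that step in pure Python under stated assumptions. `decode_fixed_length` is a hypothetical helper, not a PaddleOCR function. Unlike the diff, it offsets into `probs` by each image's start position. That is an assumption about how several images per batch would be scored, not a description of the committed code.

```python
def decode_fixed_length(preds, probs, text_len, pad_index):
    """Split flat predictions into per-image chunks and score each chunk.

    preds: flat list of predicted character indices, text_len per image.
    probs: one row of class probabilities per position in preds.
    pad_index: index treated as padding (char_num - 1 in the diff).
    """
    results = []
    for img_no in range(len(preds) // text_len):
        start = img_no * text_len
        chunk = preds[start:start + text_len]
        valid = [k for k, p in enumerate(chunk) if p != pad_index]
        if not valid:
            results.append(None)  # this image produced only padding
            continue
        # Offset by `start` so the scores come from this image's own rows.
        score = sum(max(probs[start + k]) for k in valid) / len(valid)
        # Keep everything up to and including the last non-padding position.
        results.append((chunk[:valid[-1] + 1], score))
    return results


# Three images, text_len = 3, five classes, class 4 is padding.
preds = [0, 1, 4, 4, 4, 4, 2, 4, 3]
probs = [
    [0.5, 0.25, 0.25, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 0.75, 0.25, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 0.75, 0.25],
]
decoded = decode_fixed_length(preds, probs, text_len=3, pad_index=4)
```

The second image yields `None` because every position is padding, which mirrors the `continue` in the diff.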
...@@ -208,18 +208,29 @@ def build_export(config, main_prog, startup_prog): ...@@ -208,18 +208,29 @@ def build_export(config, main_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
func_infor = config['Architecture']['function'] func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config) model = create_module(func_infor)(params=config)
algorithm = config['Global']['algorithm']
if algorithm == "SRN":
image, others, outputs = model(mode='export')
else:
image, outputs = model(mode='export') image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs.keys()]) fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name] fetches_var = [outputs[name] for name in fetches_var_name]
if algorithm == "SRN":
others_var_names = sorted([name for name in others.keys()])
feeded_var_names = [image.name] + others_var_names
else:
feeded_var_names = [image.name] feeded_var_names = [image.name]
target_vars = fetches_var target_vars = fetches_var
return feeded_var_names, target_vars, fetches_var_name return feeded_var_names, target_vars, fetches_var_name
def create_multi_devices_program(program, loss_var_name): def create_multi_devices_program(program, loss_var_name, for_quant=False):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False build_strategy.memory_optimize = False
build_strategy.enable_inplace = True build_strategy.enable_inplace = True
if for_quant:
build_strategy.fuse_all_reduce_ops = False
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1 exec_strategy.num_iteration_per_drop_scope = 1
compile_program = fluid.CompiledProgram(program).with_data_parallel( compile_program = fluid.CompiledProgram(program).with_data_parallel(
...@@ -409,7 +420,9 @@ def preprocess(): ...@@ -409,7 +420,9 @@ def preprocess():
check_gpu(use_gpu) check_gpu(use_gpu)
alg = config['Global']['algorithm'] alg = config['Global']['algorithm']
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'] assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'
]
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']: if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
config['Global']['char_ops'] = CharacterOps(config['Global']) config['Global']['char_ops'] = CharacterOps(config['Global'])
......
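Review note: in `build_export`, the SRN branch feeds the image first and then the extra inputs sorted by key. A small sketch of that ordering follows. `build_feed_names` is a hypothetical helper, and the key names in `others` are assumed to match the variable names used in the recognizer; the model code that defines them is not shown in this diff.

```python
def build_feed_names(image_name, others=None):
    """Return the feed list: image first, then extra inputs sorted by name."""
    if others is None:
        return [image_name]
    return [image_name] + sorted(others.keys())


# Hypothetical extra inputs; the key names are assumed for illustration.
others = {
    "gsrm_word_pos": object(),
    "encoder_word_pos": object(),
    "gsrm_slf_attn_bias2": object(),
    "gsrm_slf_attn_bias1": object(),
}
feed_names = build_feed_names("image", others)
plain_feed_names = build_feed_names("image")
```

Under these assumed names, the sorted order matches the order of the `inputs` list built in the recognizer's `srn` branch (image, `encoder_word_pos`, the two attention biases, then `gsrm_word_pos`). If the real keys differ, the positional inputs at inference time would need to be checked against the exported feed order.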