Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
0005f4d1
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0005f4d1
编写于
9月 23, 2020
作者:
W
wangjiawei04
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add inference to serving model tool
上级
e53c4273
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
173 addition
and
208 deletion
+173
-208
deploy/pdserving/clas_local_server.py
deploy/pdserving/clas_local_server.py
+2
-2
deploy/pdserving/clas_rpc_server.py
deploy/pdserving/clas_rpc_server.py
+2
-2
deploy/pdserving/det_local_server.py
deploy/pdserving/det_local_server.py
+2
-2
deploy/pdserving/det_rpc_server.py
deploy/pdserving/det_rpc_server.py
+2
-2
deploy/pdserving/ocr_local_server.py
deploy/pdserving/ocr_local_server.py
+4
-4
deploy/pdserving/ocr_rpc_server.py
deploy/pdserving/ocr_rpc_server.py
+3
-4
deploy/pdserving/params.py
deploy/pdserving/params.py
+50
-0
deploy/pdserving/rec_local_server.py
deploy/pdserving/rec_local_server.py
+3
-2
deploy/pdserving/rec_rpc_server.py
deploy/pdserving/rec_rpc_server.py
+2
-2
doc/doc_ch/serving_inference.md
doc/doc_ch/serving_inference.md
+67
-100
ppocr/data/det/db_process.py
ppocr/data/det/db_process.py
+4
-6
tools/export_serving_model.py
tools/export_serving_model.py
+0
-78
tools/infer/predict_cls.py
tools/infer/predict_cls.py
+1
-1
tools/infer/predict_det.py
tools/infer/predict_det.py
+1
-1
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+1
-1
tools/infer/utility.py
tools/infer/utility.py
+0
-1
tools/inference_to_serving.py
tools/inference_to_serving.py
+29
-0
未找到文件。
deploy/pdserving/clas_local_server.py
浏览文件 @
0005f4d1
...
...
@@ -22,9 +22,9 @@ import time
import
re
import
base64
from
tools.infer.predict_cls
import
TextClassifier
import
tools.infer.utility
as
utility
from
params
import
read_params
global_args
=
utility
.
parse_arg
s
()
global_args
=
read_param
s
()
if
global_args
.
use_gpu
:
from
paddle_serving_server_gpu.web_service
import
WebService
else
:
...
...
deploy/pdserving/clas_rpc_server.py
浏览文件 @
0005f4d1
...
...
@@ -22,9 +22,9 @@ import time
import
re
import
base64
from
tools.infer.predict_cls
import
TextClassifier
import
tools.infer.utility
as
utility
from
params
import
read_params
global_args
=
utility
.
parse_arg
s
()
global_args
=
read_param
s
()
if
global_args
.
use_gpu
:
from
paddle_serving_server_gpu.web_service
import
WebService
else
:
...
...
deploy/pdserving/det_local_server.py
浏览文件 @
0005f4d1
...
...
@@ -21,9 +21,9 @@ import time
import
re
import
base64
from
tools.infer.predict_det
import
TextDetector
import
tools.infer.utility
as
utility
from
params
import
read_params
global_args
=
utility
.
parse_arg
s
()
global_args
=
read_param
s
()
if
global_args
.
use_gpu
:
from
paddle_serving_server_gpu.web_service
import
WebService
else
:
...
...
deploy/pdserving/det_rpc_server.py
浏览文件 @
0005f4d1
...
...
@@ -21,9 +21,9 @@ import time
import
re
import
base64
from
tools.infer.predict_det
import
TextDetector
import
tools.infer.utility
as
utility
from
params
import
read_params
global_args
=
utility
.
parse_arg
s
()
global_args
=
read_param
s
()
if
global_args
.
use_gpu
:
from
paddle_serving_server_gpu.web_service
import
WebService
else
:
...
...
deploy/pdserving/ocr_local_server.py
浏览文件 @
0005f4d1
...
...
@@ -24,12 +24,13 @@ import base64
from
clas_local_server
import
TextClassifierHelper
from
det_local_server
import
TextDetectorHelper
from
rec_local_server
import
TextRecognizerHelper
import
tools.infer.utility
as
utility
from
tools.infer.predict_system
import
TextSystem
,
sorted_boxes
from
paddle_serving_app.local_predict
import
Debugger
import
copy
from
params
import
read_params
global_args
=
read_params
()
global_args
=
utility
.
parse_args
()
if
global_args
.
use_gpu
:
from
paddle_serving_server_gpu.web_service
import
WebService
else
:
...
...
@@ -84,8 +85,7 @@ class TextSystemHelper(TextSystem):
class
OCRService
(
WebService
):
def
init_rec
(
self
):
args
=
utility
.
parse_args
()
self
.
text_system
=
TextSystemHelper
(
args
)
self
.
text_system
=
TextSystemHelper
(
global_args
)
def
preprocess
(
self
,
feed
=
[],
fetch
=
[]):
# TODO: to handle batch rec images
...
...
deploy/pdserving/ocr_rpc_server.py
浏览文件 @
0005f4d1
...
...
@@ -24,11 +24,11 @@ import base64
from
clas_rpc_server
import
TextClassifierHelper
from
det_rpc_server
import
TextDetectorHelper
from
rec_rpc_server
import
TextRecognizerHelper
import
tools.infer.utility
as
utility
from
tools.infer.predict_system
import
TextSystem
,
sorted_boxes
import
copy
from
params
import
read_params
global_args
=
utility
.
parse_arg
s
()
global_args
=
read_param
s
()
if
global_args
.
use_gpu
:
from
paddle_serving_server_gpu.web_service
import
WebService
else
:
...
...
@@ -87,8 +87,7 @@ class TextSystemHelper(TextSystem):
class
OCRService
(
WebService
):
def
init_rec
(
self
):
args
=
utility
.
parse_args
()
self
.
text_system
=
TextSystemHelper
(
args
)
self
.
text_system
=
TextSystemHelper
(
global_args
)
def
preprocess
(
self
,
feed
=
[],
fetch
=
[]):
# TODO: to handle batch rec images
...
...
deploy/pdserving/params.py
0 → 100644
浏览文件 @
0005f4d1
# -*- coding:utf-8 -*-
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
class
Config
(
object
):
pass
def
read_params
():
cfg
=
Config
()
#use gpu
cfg
.
use_gpu
=
False
cfg
.
use_pdserving
=
True
#params for text detector
cfg
.
det_algorithm
=
"DB"
cfg
.
det_model_dir
=
"./det_mv_server/"
cfg
.
det_max_side_len
=
960
#DB parmas
cfg
.
det_db_thresh
=
0.3
cfg
.
det_db_box_thresh
=
0.5
cfg
.
det_db_unclip_ratio
=
2.0
#EAST parmas
cfg
.
det_east_score_thresh
=
0.8
cfg
.
det_east_cover_thresh
=
0.1
cfg
.
det_east_nms_thresh
=
0.2
#params for text recognizer
cfg
.
rec_algorithm
=
"CRNN"
cfg
.
rec_model_dir
=
"./ocr_rec_server/"
cfg
.
rec_image_shape
=
"3, 32, 320"
cfg
.
rec_char_type
=
'ch'
cfg
.
rec_batch_num
=
30
cfg
.
max_text_length
=
25
cfg
.
rec_char_dict_path
=
"./ppocr_keys_v1.txt"
cfg
.
use_space_char
=
True
#params for text classifier
cfg
.
use_angle_cls
=
True
cfg
.
cls_model_dir
=
"./ocr_clas_server/"
cfg
.
cls_image_shape
=
"3, 48, 192"
cfg
.
label_list
=
[
'0'
,
'180'
]
cfg
.
cls_batch_num
=
30
cfg
.
cls_thresh
=
0.9
return
cfg
deploy/pdserving/rec_local_server.py
浏览文件 @
0005f4d1
...
...
@@ -22,9 +22,10 @@ import time
import
re
import
base64
from
tools.infer.predict_rec
import
TextRecognizer
import
tools.infer.utility
as
utility
from
params
import
read_params
global_args
=
read_params
()
global_args
=
utility
.
parse_args
()
if
global_args
.
use_gpu
:
from
paddle_serving_server_gpu.web_service
import
WebService
else
:
...
...
deploy/pdserving/rec_rpc_server.py
浏览文件 @
0005f4d1
...
...
@@ -22,9 +22,9 @@ import time
import
re
import
base64
from
tools.infer.predict_rec
import
TextRecognizer
import
tools.infer.utility
as
utility
from
params
import
read_params
global_args
=
utility
.
parse_arg
s
()
global_args
=
read_param
s
()
if
global_args
.
use_gpu
:
from
paddle_serving_server_gpu.web_service
import
WebService
else
:
...
...
doc/doc_ch/serving_inference.md
浏览文件 @
0005f4d1
...
...
@@ -10,117 +10,100 @@
## 一、训练模型转Serving模型
### 检测模型转Serving模型
在前序文档
[
基于Python预测引擎推理
](
./inference.md
)
中,我们提供了如何把训练的checkpoint转换成Paddle模型。Paddle模型通常由一个文件夹构成,内含模型结构描述文件
`model`
和模型参数文件
`params`
。Serving模型由两个文件夹构成,用于存放客户端和服务端的配置。
下载超轻量级中文检测模型
:
我们以
`ch_rec_r34_vd_crnn`
模型作为例子,下载链接在
:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar && tar xf ./ch_lite/ch_det_mv3_db.tar -C ./ch_lite/
wget --no-check-certificate https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar
tar xf ch_rec_r34_vd_crnn_infer.tar
```
上述模型是以MobileNetV3为backbone训练的DB算法,将训练好的模型转换成Serving模型只需要运行如下命令:
因此我们按照Serving模型转换教程,运行下列python文件。
```
# -c后面设置训练算法的yml配置文件
# -o配置可选参数
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python tools/export_serving_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./ch_lite/det_mv3_db/best_accuracy Global.save_inference_dir=./inference/det_db/
python tools/inference_to_serving.py --model_dir ch_rec_r34_vd_crnn
```
转Serving模型时,使用的配置文件和训练时使用的配置文件相同。另外,还需要设置配置文件中的
`Global.checkpoints`
、
`Global.save_inference_dir`
参数。 其中
`Global.checkpoints`
指向训练中保存的模型参数文件,
`Global.save_inference_dir`
是生成的inference模型要保存的目录。 转换成功后,在
`save_inference_dir`
目录下有两个文件:
最终会在
`serving_client_dir`
和
`serving_server_dir`
生成客户端和服务端的模型配置。其中
`serving_server_dir`
和
`serving_client_dir`
的名字可以自定义。最终文件结构如下
```
inference/det_db
/
/ch_rec_r34_vd_crnn
/
├── serving_client_dir # 客户端配置文件夹
└── serving_server_dir # 服务端配置文件夹
```
### 识别模型转Serving模型
下载超轻量中文识别模型:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar && tar xf ./ch_lite/ch_rec_mv3_crnn.tar -C ./ch_lite/
```
识别模型转inference模型与检测的方式相同,如下:
## 二、文本检测模型Serving推理
```
# -c后面设置训练算法的yml配置文件
# -o配置可选参数
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
启动服务可以根据实际需求选择启动
`标准版`
或者
`快速版`
,两种方式的对比如下表:
python3 tools/export_serving_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints=./ch_lite/rec_mv3_crnn/best_accuracy \
Global.save_inference_dir=./inference/rec_crnn/
```
|版本|特点|适用场景|
|-|-|-|
|标准版|稳定性高,分布式部署|适用于吞吐量大,需要跨机房部署的情况|
|快速版|部署方便,预测速度快|适用于对预测速度要求高,迭代速度快的场景,Windows用户只能选择快速版|
**注意:**
如果您是在自己的数据集上训练的模型,并且调整了中文字符的字典文件,请注意修改配置文件中的
`character_dict_path`
是否是所需要的字典文件。
接下来的命令中,我们会指定快速版和标准版的命令。需要说明的是,标准版只能用Linux平台,快速版可以支持Linux/Windows。
文本检测模型推理,默认使用DB模型的配置参数,识别默认为CRNN。
转换成功后,在目录下有两个文件:
配置文件在
`params.py`
中,我们贴出配置部分,如果需要做改动,也在这个文件内部进行修改。
```
/inference/rec_crnn/
├── serving_client_dir # 客户端配置文件夹
└── serving_server_dir # 服务端配置文件夹
```
def read_params():
cfg = Config()
#use gpu
cfg.use_gpu = False # 是否使用GPU
cfg.use_pdserving = True # 是否使用paddleserving,必须为True
### 方向分类模型转Serving模型
#params for text detector
cfg.det_algorithm = "DB" # 检测算法, DB/EAST等
cfg.det_model_dir = "./det_mv_server/" # 检测算法模型路径
cfg.det_max_side_len = 960
下载方向分类模型:
#DB params
cfg.det_db_thresh =0.3
cfg.det_db_box_thresh =0.5
cfg.det_db_unclip_ratio =2.0
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/
```
#EAST params
cfg.det_east_score_thresh = 0.8
cfg.det_east_cover_thresh = 0.1
cfg.det_east_nms_thresh = 0.2
方向分类模型转inference模型与检测的方式相同,如下:
#params for text recognizer
cfg.rec_algorithm = "CRNN" # 识别算法, CRNN/RARE等
cfg.rec_model_dir = "./ocr_rec_server/" # 识别算法模型路径
```
# -c后面设置训练算法的yml配置文件
# -o配置可选参数
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
cfg.rec_image_shape = "3, 32, 320"
cfg.rec_char_type = 'ch'
cfg.rec_batch_num = 30
cfg.max_text_length = 25
python3 tools/export_serving_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \
Global.save_inference_dir=./inference/cls/
```
cfg.rec_char_dict_path = "./ppocr_keys_v1.txt" # 识别算法字典文件
cfg.use_space_char = True
转换成功后,在目录下有两个文件:
#params for text classifier
cfg.use_angle_cls = True # 是否启用分类算法
cfg.cls_model_dir = "./ocr_clas_server/" # 分类算法模型路径
cfg.cls_image_shape = "3, 48, 192"
cfg.label_list = ['0', '180']
cfg.cls_batch_num = 30
cfg.cls_thresh = 0.9
return cfg
```
/inference/cls/
├── serving_client_dir # 客户端配置文件夹
└── serving_server_dir # 服务端配置文件夹
```
在接下来的教程中,我们将给出推理的demo模型下载链接。
与本地预测不同的是,Serving预测需要一个客户端和一个服务端,因此接下来的教程都是两行代码。
在正式执行服务端启动命令之前,先export PYTHONPATH到工程主目录下。
```
wget --no-check-certificate https://paddleocr.bj.bcebos.com/deploy/pdserving/ocr_pdserving_suite.tar.gz
tar zxf ocr_pdserving_suite.tar.gz
export PYTHONPATH=$PWD:$PYTHONPATH
cd deploy/pdserving
```
## 二、文本检测模型Serving推理
文本检测模型推理,默认使用DB模型的配置参数。当不使用DB模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
与本地预测不同的是,Serving预测需要一个客户端和一个服务端,因此接下来的教程都是两行代码。所有的
### 1. 超轻量中文检测模型推理
超轻量中文检测模型推理,可以执行如下命令启动服务端:
```
#根据环境只需要启动其中一个就可以
python det_rpc_server.py
--use_pdserving True --det_model_dir det_mv_server
#标准版,Linux用户
python det_local_server.py
--use_pdserving True --det_model_dir det_mv_server
#快速版,Windows/Linux用户
python det_rpc_server.py #标准版,Linux用户
python det_local_server.py #快速版,Windows/Linux用户
```
如果需要使用CPU版本,还需增加
`--use_gpu False`
。
客户端
...
...
@@ -129,23 +112,8 @@ python det_web_client.py
```
Serving的推测和本地预测不同点在于,客户端发送请求到服务端,服务端需要检测到文字框之后返回框的坐标,此处没有后处理的图片,只能看到坐标值。
### 2. DB文本检测模型推理
首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例(
[
模型下载地址
](
https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar
)
),可以使用如下命令进行转换:
```
# -c后面设置训练算法的yml配置文件
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_serving_model.py -c configs/det/det_r50_vd_db.yml -o Global.checkpoints="./models/det_r50_vd_db/best_accuracy" Global.save_inference_dir="./inference/det_db"
```
经过转换之后,会在
`./inference/det_db`
目录下出现
`serving_server_dir`
和
`serving_client_dir`
,然后指定
`det_model_dir`
。
## 三、文本识别模型Serving推理
下面将介绍超轻量中文识别模型推理、基于CTC损失的识别模型推理和基于Attention损失的识别模型推理。对于中文文本识别,建议优先选择基于CTC损失的识别模型,实践中也发现基于Attention损失的效果不如基于CTC损失的识别模型。此外,如果训练时修改了文本的字典,请参考下面的自定义文本识别字典的推理。
...
...
@@ -153,11 +121,11 @@ python3 tools/export_serving_model.py -c configs/det/det_r50_vd_db.yml -o Global
### 1. 超轻量中文识别模型推理
超轻量中文识别模型推理,可以执行如下命令启动服务端:
需要注意params.py中的
`--use_gpu`
的值
```
#根据环境只需要启动其中一个就可以
python rec_rpc_server.py
--use_pdserving True --rec_model_dir ocr_rec_server
#标准版,Linux用户
python rec_local_server.py
--use_pdserving True --rec_model_dir ocr_rec_server
#快速版,Windows/Linux用户
python rec_rpc_server.py #标准版,Linux用户
python rec_local_server.py #快速版,Windows/Linux用户
```
如果需要使用CPU版本,还需增加
`--use_gpu False`
。
...
...
@@ -186,13 +154,12 @@ python rec_web_client.py
### 1. 方向分类模型推理
方向分类模型推理, 可以执行如下命令启动服务端:
需要注意params.py中的
`--use_gpu`
的值
```
#根据环境只需要启动其中一个就可以
python clas_rpc_server.py
--use_pdserving True --cls_model_dir ocr_clas_server
#标准版,Linux用户
python clas_local_server.py
--use_pdserving True --cls_model_dir ocr_clas_server
#快速版,Windows/Linux用户
python clas_rpc_server.py #标准版,Linux用户
python clas_local_server.py #快速版,Windows/Linux用户
```
如果需要使用CPU版本,还需增加
`--use_gpu False`
。
客户端
...
...
@@ -216,20 +183,20 @@ python rec_web_client.py
在执行预测时,需要通过参数
`image_dir`
指定单张图像或者图像集合的路径、参数
`det_model_dir`
,
`cls_model_dir`
和
`rec_model_dir`
分别指定检测,方向分类和识别的inference模型路径。参数
`use_angle_cls`
用于控制是否启用方向分类模型。与本地预测不同的是,为了减少网络传输耗时,可视化识别结果目前不做处理,用户收到的是推理得到的文字字段。
执行如下命令启动服务端:
需要注意params.py中的
`--use_gpu`
的值
```
#标准版,Linux用户
#GPU用户
python -m paddle_serving_server_gpu.serve --model det_mv_server --port 9293 --gpu_id 0
python -m paddle_serving_server_gpu.serve --model ocr_cls_server --port 9294 --gpu_id 0
python ocr_rpc_server.py
--use_pdserving True --use_gpu True --rec_model_dir ocr_rec_server
python ocr_rpc_server.py
#CPU用户
python -m paddle_serving_server.serve --model det_mv_server --port 9293
python -m paddle_serving_server.serve --model ocr_cls_server --port 9294
python ocr_rpc_server.py
--use_pdserving True --use_gpu False --rec_model_dir ocr_rec_server
python ocr_rpc_server.py
#快速版,Windows/Linux用户
python ocr_local_server.py
--use_gpu False --use_pdserving True --rec_model_dir ocr_rec_server/ --det_model_dir det_mv_server/ --cls_model_dir ocr_clas_server/ --rec_char_dict_path ppocr_keys_v1.txt --use_angle_cls True
python ocr_local_server.py
```
客户端
...
...
ppocr/data/det/db_process.py
浏览文件 @
0005f4d1
...
...
@@ -21,12 +21,10 @@ from ppocr.utils.utility import initial_logger, check_and_read_gif
logger
=
initial_logger
()
import
tools.infer.utility
as
utility
args
=
utility
.
parse_args
()
if
args
.
use_pdserving
is
False
:
from
.data_augment
import
AugmentData
from
.random_crop_data
import
RandomCropData
from
.make_shrink_map
import
MakeShrinkMap
from
.make_border_map
import
MakeBorderMap
from
.data_augment
import
AugmentData
from
.random_crop_data
import
RandomCropData
from
.make_shrink_map
import
MakeShrinkMap
from
.make_border_map
import
MakeBorderMap
class
DBProcessTrain
(
object
):
...
...
tools/export_serving_model.py
已删除
100644 → 0
浏览文件 @
e53c4273
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
import
program
from
paddle
import
fluid
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
from
ppocr.utils.save_load
import
init_model
from
paddle_serving_client.io
import
save_model
def
main
():
startup_prog
,
eval_program
,
place
,
config
,
_
=
program
.
preprocess
()
feeded_var_names
,
target_vars
,
fetches_var_name
=
program
.
build_export
(
config
,
eval_program
,
startup_prog
)
eval_program
=
eval_program
.
clone
(
for_test
=
True
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
init_model
(
config
,
eval_program
,
exe
)
save_inference_dir
=
config
[
'Global'
][
'save_inference_dir'
]
if
not
os
.
path
.
exists
(
save_inference_dir
):
os
.
makedirs
(
save_inference_dir
)
serving_client_dir
=
"{}/serving_client_dir"
.
format
(
save_inference_dir
)
serving_server_dir
=
"{}/serving_server_dir"
.
format
(
save_inference_dir
)
feed_dict
=
{
x
:
eval_program
.
global_block
().
var
(
x
)
for
x
in
feeded_var_names
}
fetch_dict
=
{
x
.
name
:
x
for
x
in
target_vars
}
save_model
(
serving_server_dir
,
serving_client_dir
,
feed_dict
,
fetch_dict
,
eval_program
)
print
(
"paddle serving model saved in {}/serving_server_dir and {}/serving_client_dir"
.
format
(
save_inference_dir
,
save_inference_dir
))
print
(
"save success, output_name_list:"
,
fetches_var_name
)
if
__name__
==
'__main__'
:
main
()
tools/infer/predict_cls.py
浏览文件 @
0005f4d1
...
...
@@ -36,10 +36,10 @@ class TextClassifier(object):
if
args
.
use_pdserving
is
False
:
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
mode
=
"cls"
)
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
self
.
cls_image_shape
=
[
int
(
v
)
for
v
in
args
.
cls_image_shape
.
split
(
","
)]
self
.
cls_batch_num
=
args
.
rec_batch_num
self
.
label_list
=
args
.
label_list
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
self
.
cls_thresh
=
args
.
cls_thresh
def
resize_norm_img
(
self
,
img
):
...
...
tools/infer/predict_det.py
浏览文件 @
0005f4d1
...
...
@@ -42,7 +42,6 @@ class TextDetector(object):
def
__init__
(
self
,
args
):
max_side_len
=
args
.
det_max_side_len
self
.
det_algorithm
=
args
.
det_algorithm
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
preprocess_params
=
{
'max_side_len'
:
max_side_len
}
postprocess_params
=
{}
if
self
.
det_algorithm
==
"DB"
:
...
...
@@ -76,6 +75,7 @@ class TextDetector(object):
logger
.
info
(
"unknown det_algorithm:{}"
.
format
(
self
.
det_algorithm
))
sys
.
exit
(
0
)
if
args
.
use_pdserving
is
False
:
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
mode
=
"det"
)
...
...
tools/infer/predict_rec.py
浏览文件 @
0005f4d1
...
...
@@ -37,12 +37,12 @@ class TextRecognizer(object):
if
args
.
use_pdserving
is
False
:
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
mode
=
"rec"
)
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
self
.
rec_image_shape
=
[
int
(
v
)
for
v
in
args
.
rec_image_shape
.
split
(
","
)]
self
.
character_type
=
args
.
rec_char_type
self
.
rec_batch_num
=
args
.
rec_batch_num
self
.
rec_algorithm
=
args
.
rec_algorithm
self
.
text_len
=
args
.
max_text_length
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
char_ops_params
=
{
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
...
...
tools/infer/utility.py
浏览文件 @
0005f4d1
...
...
@@ -37,7 +37,6 @@ def parse_args():
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
8000
)
parser
.
add_argument
(
"--use_pdserving"
,
type
=
str2bool
,
default
=
False
)
# params for text detector
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
...
...
tools/inference_to_serving.py
0 → 100644
浏览文件 @
0005f4d1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
paddle_serving_client.io
import
inference_model_to_serving
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--server_dir"
,
type
=
str
,
default
=
"serving_server_dir"
)
parser
.
add_argument
(
"--client_dir"
,
type
=
str
,
default
=
"serving_client_dir"
)
return
parser
.
parse_args
()
args
=
parse_args
()
inference_model_dir
=
args
.
model_dir
serving_client_dir
=
args
.
server_dir
serving_server_dir
=
args
.
client_dir
feed_var_names
,
fetch_var_names
=
inference_model_to_serving
(
inference_model_dir
,
serving_client_dir
,
serving_server_dir
,
model_filename
=
"model"
,
params_filename
=
"params"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录