Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
4db1820d
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4db1820d
编写于
4月 10, 2020
作者:
D
dyning
提交者:
GitHub
4月 10, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7 from shippingwang/master
refine inference code
上级
eaf07d64
e6835830
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
346 addition
and
73 deletion
+346
-73
docs/zh_cn/extension/paddle_inference.md
docs/zh_cn/extension/paddle_inference.md
+262
-0
docs/zh_cn/tutorials/getting_started.md
docs/zh_cn/tutorials/getting_started.md
+24
-2
tools/infer/predict.py
tools/infer/predict.py
+60
-21
tools/infer/run.sh
tools/infer/run.sh
+0
-49
tools/infer/utils.py
tools/infer/utils.py
+0
-1
未找到文件。
docs/zh_cn/extension/paddle_inference.md
0 → 100644
浏览文件 @
4db1820d
# 分类预测框架
## 一、简介
Paddle 的模型保存有多种不同的形式,大体可分为两类:
1.
persistable 模型(fluid.save_persistabels保存的模型)
一般做为模型的 checkpoint,可以加载后重新训练。persistable 模型保存的是零散的权重文件,每个文件代表模型中的一个 Variable,这些零散的文件不包含结构信息,需要结合模型的结构一起使用。
```
resnet50-vd-persistable/
├── bn2a_branch1_mean
├── bn2a_branch1_offset
├── bn2a_branch1_scale
├── bn2a_branch1_variance
├── bn2a_branch2a_mean
├── bn2a_branch2a_offset
├── bn2a_branch2a_scale
├── ...
└── res5c_branch2c_weights
```
2.
inference 模型(fluid.io.save_inference_model保存的模型)
一般是模型训练完成后保存的固化模型,用于预测部署。与 persistable 模型相比,inference 模型会额外保存模型的结构信息,用于配合权重文件构成完整的模型。如下所示,
`model`
中保存的即为模型的结构信息。
```
resnet50-vd-persistable/
├── bn2a_branch1_mean
├── bn2a_branch1_offset
├── bn2a_branch1_scale
├── bn2a_branch1_variance
├── bn2a_branch2a_mean
├── bn2a_branch2a_offset
├── bn2a_branch2a_scale
├── ...
├── res5c_branch2c_weights
└── model
```
为了方便起见,paddle 在保存 inference 模型的时候也可以将所有的权重文件保存成一个
`params`
文件,如下所示:
```
resnet50-vd
├── model
└── params
```
在 Paddle 中训练引擎和预测引擎都支持模型的预测推理,只不过预测引擎不需要进行反向操作,因此可以进行定制型的优化(如层融合,kernel 选择等),达到低时延、高吞吐的目的。训练引擎既可以支持 persistable 模型,也可以支持 inference 模型,而预测引擎只支持 inference 模型,因此也就衍生出了三种不同的预测方式:
1.
预测引擎 + inference 模型
2.
训练引擎 + persistable 模型
3.
训练引擎 + inference 模型
不管是何种预测方式,基本都包含以下几个主要的步骤:
+
构建引擎
+
构建待预测数据
+
执行预测
+
预测结果解析
不同预测方式,主要有两方面不同:构建引擎和执行预测,以下的几个部分我们会具体介绍。
## 二、模型转换
在任务的训练阶段,通常我们会保存一些 checkpoint(persistable 模型),这些只是模型权重文件,不能直接被预测引擎直接加载预测,所以我们通常会在训练完之后,找到合适的 checkpoint 并将其转换为 inference 模型。主要分为两个步骤:1. 构建训练引擎,2. 保存 inference 模型,如下所示:
```
python
import
fluid
from
ppcls.modeling.architectures.resnet_vd
import
ResNet50_vd
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
startup_prog
=
fluid
.
Program
()
infer_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
infer_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
image
=
create_input
()
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
[
None
,
3
,
224
,
224
],
dtype
=
'float32'
)
out
=
ResNet50_vd
.
net
(
input
=
input
,
class_dim
=
1000
)
infer_prog
=
infer_prog
.
clone
(
for_test
=
True
)
fluid
.
load
(
program
=
infer_prog
,
model_path
=
persistable
模型路径
,
executor
=
exe
)
fluid
.
io
.
save_inference_model
(
dirname
=
'./output/'
,
feeded_var_names
=
[
image
.
name
],
main_program
=
infer_prog
,
target_vars
=
out
,
executor
=
exe
,
model_filename
=
'model'
,
params_filename
=
'params'
)
```
在模型库的
`tools/export_model.py`
中提供了完整的示例,只需执行下述命令即可完成转换:
```
python
python
tools
/
export_model
.
py
\
--
m
=
模型名称
\
--
p
=
persistable
模型路径
\
--
o
=
model和params保存路径
```
## 三、预测引擎 + inference 模型预测
在模型库的
`tools/predict.py`
中提供了完整的示例,只需执行下述命令即可完成预测:
```
python ./predict.py \
-i=./test.jpeg \
-m=./resnet50-vd/model \
-p=./resnet50-vd/params \
--use_gpu=1 \
--use_tensorrt=True
```
参数说明:
+
`image_file`
(简写 i):待预测的图片文件路径,如
`./test.jpeg`
+
`model_file`
(简写 m):模型文件路径,如
`./resnet50-vd/model`
+
`params_file`
(简写 p):权重文件路径,如
`./resnet50-vd/params`
+
`batch_size`
(简写 b):批大小,如
`1`
+
`ir_optim`
:是否使用
`IR`
优化,默认值:True
+
`use_tensorrt`
:是否使用 TesorRT 预测引擎,默认值:True
+
`gpu_mem`
: 初始分配GPU显存,以M单位
+
`use_gpu`
:是否使用 GPU 预测,默认值:True
+
`enable_benchmark`
:是否启用benchmark,默认值:False
+
`model_name`
:模型名字
注意:
当启用benchmark时,默认开启tersorrt进行预测
构建预测引擎:
```
python
from
paddle.fluid.core
import
AnalysisConfig
from
paddle.fluid.core
import
create_paddle_predictor
config
=
AnalysisConfig
(
model文件路径
,
params文件路径
)
config
.
enable_use_gpu
(
8000
,
0
)
config
.
disable_glog_info
()
config
.
switch_ir_optim
(
True
)
config
.
enable_tensorrt_engine
(
precision_mode
=
AnalysisConfig
.
Precision
.
Float32
,
max_batch_size
=
1
)
# no zero copy方式需要去除fetch feed op
config
.
switch_use_feed_fetch_ops
(
False
)
predictor
=
create_paddle_predictor
(
config
)
```
执行预测:
```
python
import
numpy
as
np
input_names
=
predictor
.
get_input_names
()
input_tensor
=
predictor
.
get_input_tensor
(
input_names
[
0
])
input
=
np
.
random
.
randn
(
1
,
3
,
224
,
224
).
astype
(
"float32"
)
input_tensor
.
reshape
([
1
,
3
,
224
,
224
])
input_tensor
.
copy_from_cpu
(
input
)
predictor
.
zero_copy_run
()
```
更多预测参数说明可以参考官网
[
Paddle Python 预测 API
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/python_infer_cn.html
)
。如果需要在业务的生产环境部署,也推荐使用
[
Paddel C++ 预测 API
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html
)
,官网提供了丰富的预编译预测库
[
Paddle C++ 预测库
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html
)
。
默认情况下,Paddle 的 wheel 包中是不包含 TensorRT 预测引擎的,如果需要使用 TensorRT 进行预测优化,需要自己编译对应的 wheel 包,编译方式可以参考 Paddle 的编译指南
[
Paddle 编译
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/fromsource.html
)
。
## 四、训练引擎 + persistable 模型预测
在模型库的
`tools/infer.py`
中提供了完整的示例,只需执行下述命令即可完成预测:
```
python
python
tools
/
infer
.
py
\
--
i
=
待预测的图片文件路径
\
--
m
=
模型名称
\
--
p
=
persistable
模型路径
\
--
use_gpu
=
True
```
参数说明:
+
`image_file`
(简写 i):待预测的图片文件路径,如
`./test.jpeg`
+
`model_file`
(简写 m):模型文件路径,如
`./resnet50-vd/model`
+
`params_file`
(简写 p):权重文件路径,如
`./resnet50-vd/params`
+
`use_gpu`
: 是否开启GPU训练,默认值:True
训练引擎构建:
由于 persistable 模型不包含模型的结构信息,因此需要先构建出网络结构,然后 load 权重来构建训练引擎。
```
python
import
fluid
from
ppcls.modeling.architectures.resnet_vd
import
ResNet50_vd
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
startup_prog
=
fluid
.
Program
()
infer_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
infer_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
image
=
create_input
()
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
[
None
,
3
,
224
,
224
],
dtype
=
'float32'
)
out
=
ResNet50_vd
.
net
(
input
=
input
,
class_dim
=
1000
)
infer_prog
=
infer_prog
.
clone
(
for_test
=
True
)
fluid
.
load
(
program
=
infer_prog
,
model_path
=
persistable
模型路径
,
executor
=
exe
)
```
执行预测:
```
python
outputs
=
exe
.
run
(
infer_prog
,
feed
=
{
image
.
name
:
data
},
fetch_list
=
[
out
.
name
],
return_numpy
=
False
)
```
上述执行预测时候的参数说明可以参考官网
[
fluid.Executor
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/executor_cn/Executor_cn.html
)
## 五、训练引擎 + inference 模型预测
在模型库的
`tools/py_infer.py`
中提供了完整的示例,只需执行下述命令即可完成预测:
```
python
python
tools
/
py_infer
.
py
\
--
i
=
图片路径
\
--
d
=
模型的存储路径
\
--
m
=
保存的模型文件
\
--
p
=
保存的参数文件
\
--
use_gpu
=
True
```
+
`image_file`
(简写 i):待预测的图片文件路径,如
`./test.jpeg`
+
`model_file`
(简写 m):模型文件路径,如
`./resnet50_vd/model`
+
`params_file`
(简写 p):权重文件路径,如
`./resnet50_vd/params`
+
`model_dir`
(简写d):模型路径,如
`./resent50_vd`
+
`use_gpu`
:是否开启GPU,默认值:True
训练引擎构建:
由于 inference 模型已包含模型的结构信息,因此不再需要提前构建模型结构,直接 load 模型结构和权重文件来构建训练引擎。
```
python
import
fluid
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
[
program
,
feed_names
,
fetch_lists
]
=
fluid
.
io
.
load_inference_model
(
模型的存储路径
,
exe
,
model_filename
=
保存的模型文件
,
params_filename
=
保存的参数文件
)
compiled_program
=
fluid
.
compiler
.
CompiledProgram
(
program
)
```
> `load_inference_model` 既支持零散的权重文件集合,也支持融合后的单个权重文件。
执行预测:
```
python
outputs
=
exe
.
run
(
compiled_program
,
feed
=
{
feed_names
[
0
]:
data
},
fetch_list
=
fetch_lists
,
return_numpy
=
False
)
```
上述执行预测时候的参数说明可以参考官网
[
fluid.Executor
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/executor_cn/Executor_cn.html
)
docs/zh_cn/tutorials/getting_started.md
浏览文件 @
4db1820d
...
...
@@ -2,7 +2,7 @@
---
请事先参考
[
安装指南
](
install.md
)
配置运行环境
##
1
设置环境变量
##
一、
设置环境变量
**设置PYTHONPATH环境变量:**
...
...
@@ -10,7 +10,7 @@
export
PYTHONPATH
=
path_to_PaddleClas:
$PYTHONPATH
```
##
2
模型训练与评估
##
二、
模型训练与评估
PaddleClas 提供模型训练与评估脚本:tools/train.py和tools/eval.py
...
...
@@ -62,3 +62,25 @@ python eval.py \
-o
pretrained_model
=
path_to_pretrained_models
```
您可以更改configs/eval.yaml中的architecture字段和pretrained_model字段来配置评估模型,或是通过-o参数更新配置。
## 三、模型推理
PaddlePaddle提供三种方式进行预测推理,接下来介绍如何用预测引擎进行推理:
首先,对训练好的模型进行转换
```
bash
python tools/export_model.py
\
-model
=
模型名字
\
-pretrained_model
=
预训练模型路径
\
-output_path
=
预测模型保存路径
```
之后,通过预测引擎进行推理
```
bash
python tools/predict.py
\
-m
model文件路径
\
-p
params文件路径
\
-i
图片路径
\
--use_gpu
=
1
\
--use_tensorrt
=
True
```
更多使用方法和推理方式请参考
[
分类预测框架
](
../extension/paddle_inference.md
)
tools/infer/
cpp_infer
.py
→
tools/infer/
predict
.py
浏览文件 @
4db1820d
...
...
@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
utils
import
argparse
import
numpy
as
np
import
logging
import
time
from
paddle.fluid.core
import
PaddleTensor
from
paddle.fluid.core
import
AnalysisConfig
from
paddle.fluid.core
import
create_paddle_predictor
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logger
=
logging
.
getLogger
(
__name__
)
def
parse_args
():
def
str2bool
(
v
):
...
...
@@ -29,26 +32,38 @@ def parse_args():
parser
.
add_argument
(
"-i"
,
"--image_file"
,
type
=
str
)
parser
.
add_argument
(
"-m"
,
"--model_file"
,
type
=
str
)
parser
.
add_argument
(
"-p"
,
"--params_file"
,
type
=
str
)
parser
.
add_argument
(
"-b"
,
"--max_batch_size"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"-b"
,
"--batch_size"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--use_fp16"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
8000
)
parser
.
add_argument
(
"--enable_benchmark"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--model_name"
,
type
=
str
)
return
parser
.
parse_args
()
def
create_predictor
(
args
):
config
=
AnalysisConfig
(
args
.
model_file
,
args
.
params_file
)
if
args
.
use_gpu
:
config
.
enable_use_gpu
(
1000
,
0
)
config
.
enable_use_gpu
(
args
.
gpu_mem
,
0
)
else
:
config
.
disable_gpu
()
config
.
switch_ir_optim
(
args
.
ir_optim
)
# default true
config
.
disable_glog_info
()
config
.
switch_ir_optim
(
args
.
ir_optim
)
# default true
if
args
.
use_tensorrt
:
config
.
enable_tensorrt_engine
(
precision_mode
=
AnalysisConfig
.
Precision
.
Float32
,
max_batch_size
=
args
.
max_batch_size
)
precision_mode
=
AnalysisConfig
.
Precision
.
Half
if
args
.
use_fp16
else
AnalysisConfig
.
Precision
.
Float32
,
max_batch_size
=
args
.
batch_size
)
config
.
enable_memory_optim
()
# use zero copy
config
.
switch_use_feed_fetch_ops
(
False
)
predictor
=
create_paddle_predictor
(
config
)
return
predictor
...
...
@@ -64,7 +79,7 @@ def create_operators():
resize_op
=
utils
.
ResizeImage
(
resize_short
=
256
)
crop_op
=
utils
.
CropImage
(
size
=
(
size
,
size
))
normalize_op
=
utils
.
NormalizeImage
(
scale
=
img_scale
,
mean
=
img_mean
,
std
=
img_std
)
scale
=
img_scale
,
mean
=
img_mean
,
std
=
img_std
)
totensor_op
=
utils
.
ToTensor
()
return
[
decode_op
,
resize_op
,
crop_op
,
normalize_op
,
totensor_op
]
...
...
@@ -78,25 +93,49 @@ def preprocess(fname, ops):
return
data
def
postprocess
(
outputs
,
topk
=
5
):
output
=
outputs
[
0
]
prob
=
output
.
as_ndarray
().
flatten
()
index
=
prob
.
argsort
(
axis
=
0
)[
-
topk
:][::
-
1
].
astype
(
'int32'
)
return
zip
(
index
,
prob
[
index
])
def
main
():
args
=
parse_args
()
if
not
args
.
enable_benchmark
:
assert
args
.
batch_size
==
1
assert
args
.
use_fp16
==
False
else
:
assert
args
.
use_gpu
==
True
assert
args
.
model_name
is
not
None
assert
args
.
use_tensorrt
==
True
# HALF precission predict only work when using tensorrt
if
args
.
use_fp16
==
True
:
assert
args
.
use_tensorrt
==
True
operators
=
create_operators
()
predictor
=
create_predictor
(
args
)
data
=
preprocess
(
args
.
image_file
,
operators
)
inputs
=
[
PaddleTensor
(
data
.
copy
())]
outputs
=
predictor
.
run
(
inputs
)
probs
=
postprocess
(
outputs
)
inputs
=
preprocess
(
args
.
image_file
,
operators
)
inputs
=
np
.
expand_dims
(
inputs
,
axis
=
0
).
repeat
(
args
.
batch_size
,
axis
=
0
).
copy
()
for
idx
,
prob
in
probs
:
print
(
"class id: {:d}, probability: {:.4f}"
.
format
(
idx
,
prob
))
input_names
=
predictor
.
get_input_names
()
input_tensor
=
predictor
.
get_input_tensor
(
input_names
[
0
])
input_tensor
.
copy_from_cpu
(
inputs
)
if
not
args
.
enable_benchmark
:
predictor
.
zero_copy_run
()
else
:
for
i
in
range
(
0
,
1010
):
if
i
==
10
:
start
=
time
.
time
()
predictor
.
zero_copy_run
()
end
=
time
.
time
()
fp_message
=
"FP16"
if
args
.
use_fp16
else
"FP32"
logger
.
info
(
"{0}
\t
{1}
\t
batch size: {2}
\t
time(ms): {3}"
.
format
(
args
.
model_name
,
fp_message
,
args
.
batch_size
,
end
-
start
))
output_names
=
predictor
.
get_output_names
()
output_tensor
=
predictor
.
get_output_tensor
(
output_names
[
0
])
output
=
output_tensor
.
copy_to_cpu
()
output
=
output
.
flatten
()
cls
=
np
.
argmax
(
output
)
score
=
output
[
cls
]
logger
.
info
(
"class: {0}"
.
format
(
cls
))
logger
.
info
(
"score: {0}"
.
format
(
score
))
if
__name__
==
"__main__"
:
...
...
tools/infer/run.sh
已删除
100644 → 0
浏览文件 @
eaf07d64
#!/usr/bin/env bash
python ./cpp_infer.py
\
-i
=
./test.jpeg
\
-m
=
./resnet50-vd/model
\
-p
=
./resnet50-vd/params
\
--use_gpu
=
1
python ./cpp_infer.py
\
-i
=
./test.jpeg
\
-m
=
./resnet50-vd/model
\
-p
=
./resnet50-vd/params
\
--use_gpu
=
0
python py_infer.py
\
-i
=
./test.jpeg
\
-d
./resnet50-vd/
\
-m
=
model
-p
=
params
\
--use_gpu
=
0
python py_infer.py
\
-i
=
./test.jpeg
\
-d
./resnet50-vd/
\
-m
=
model
-p
=
params
\
--use_gpu
=
1
python infer.py
\
-i
=
./test.jpeg
\
-m
ResNet50_vd
\
-p
./resnet50-vd-persistable/
\
--use_gpu
=
0
python infer.py
\
-i
=
./test.jpeg
\
-m
ResNet50_vd
\
-p
./resnet50-vd-persistable/
\
--use_gpu
=
1
python export_model.py
\
-m
ResNet50_vd
\
-p
./resnet50-vd-persistable/
\
-o
./test/
python py_infer.py
\
-i
=
./test.jpeg
\
-d
./test/
\
-m
=
model
\
-p
=
params
\
--use_gpu
=
0
tools/infer/utils.py
浏览文件 @
4db1820d
...
...
@@ -81,5 +81,4 @@ class ToTensor(object):
def
__call__
(
self
,
img
):
img
=
img
.
transpose
((
2
,
0
,
1
))
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
return
img
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录