提交 4e794c67 编写于 作者: S sjtubinlong

update python inferences codes and docs

上级 3e19600f
......@@ -68,7 +68,8 @@ $ pip install -r requirements.txt
### 预测部署
* [模型导出](./docs/model_export.md)
* [C++预测库使用](./inference)
* [使用Python预测](./deploy/python/)
* [使用C++预测](./deploy/cpp/)
### 高级功能
......
# PaddleSeg 预测部署
`PaddleSeg`目前支持使用`Python``C++`部署在`Windows``Linux` 上, 也可以集成`PaddleServing`服务化部署在 `Linux` 上。
[1. Python预测(支持 Linux 和 Windows)](./python/)
[2. C++预测(支持 Linux 和 Windows)](./cpp/)
[3. 服务化部署(仅支持 Linux)](./serving)
......@@ -11,7 +11,7 @@
## 1.说明
本目录提供一个跨平台的图像分割模型的C++、Python预测部署方案,用户通过一定的配置,加上少量的代码,即可把模型集成到自己的服务中,完成图像分割的任务。
本目录提供一个跨平台`PaddlePaddle`图像分割模型的`C++`预测部署方案,用户通过一定的配置,加上少量的代码,即可把模型集成到自己的服务中,完成图像分割的任务。
主要设计的目标包括以下四点:
- 跨平台,支持在 windows 和 Linux 完成编译、开发和部署
......@@ -19,6 +19,8 @@
- 可扩展性,支持用户针对新模型开发自己特殊的数据预处理、后处理等逻辑
- 高性能,除了`PaddlePaddle`自身带来的性能优势,我们还针对图像分割的特点对关键步骤进行了性能优化
**注意** 如需要使用`Python`的预测部署方法,请参考:[Python预测部署](../python/)
## 2.主要目录和文件
......@@ -57,8 +59,6 @@ inference
`Windows`上推荐使用最新的`Visual Studio 2019 Community`直接编译`CMake`项目。
针对Python的预测部署方法,可参考以下链接:[Python预测部署方法](python_inference.md)
## 4.预测并可视化结果
完成编译后,便生成了需要的可执行文件和链接库,然后执行以下步骤:
......
# PaddleSeg Python 预测部署方案
本文档旨在提供一个`PaddlePaddle`跨平台图像分割模型的`Python`预测部署方案,用户通过一定的配置,加上少量的代码,即可把模型集成到自己的服务中,完成图像分割的任务。
## 前置条件
* Python2.7+,Python3+
* pip,pip3
## 主要目录和文件
```
├── infer.py # 核心代码,完成分割模型的预测以及结果可视化
├── requirements.txt # 依赖的Python包
└── README.md # 说明文档
```
### Step1:安装PaddlePaddle
如何选择合适版本的`PaddlePaddle`版本进行安装,可参考: [PaddlePaddle安装教程](https://www.paddlepaddle.org.cn/install/doc/)
### Step2:安装Python依赖包
2.1 在**当前**目录下, 使用`pip`安装`Python`依赖包
```bash
pip install -r requirements.txt
```
2.2 安装`OpenCV` 相关依赖库
预测代码中需要使用`OpenCV`,所以还需要`OpenCV`安装相关的动态链接库。
`Ubuntu``CentOS` 为例,命令如下:
`Ubuntu`下安装相关链接库:
```bash
apt-get install -y libglib2.0-0 libsm6 libxext6 libxrender-dev
```
CentOS 下安装相关链接库:
```bash
yum install -y libXext libSM libXrender
```
### Step3:预测
进行预测前, 请使用[模型导出工具](../../docs/model_export.md) 导出您的模型(或点击下载我们的[人像分割样例模型](https://bj.bcebos.com/paddleseg/inference/human_freeze_model.zip)用于测试)。
导出的模型目录通常包括三个文件,除了模型文件`models` 和参数文件`params`,还会生成对应的配置文件`deploy.yaml`用于`C++``Python` 预测, 主要字段及其含义如下:
```yaml
DEPLOY:
# 是否使用GPU预测
USE_GPU: 1
# 模型和参数文件所在目录路径
MODEL_PATH: "/root/projects/models/deeplabv3p_xception65_humanseg"
# 模型文件名
MODEL_FILENAME: "__model__"
# 参数文件名
PARAMS_FILENAME: "__params__"
# 预测图片的的标准输入尺寸,输入尺寸不一致会做resize
EVAL_CROP_SIZE: (513, 513)
# 均值
MEAN: [0.5, 0.5, 0.5]
# 方差
STD: [0.5, 0.5, 0.5]
# 分类类型数
NUM_CLASSES: 2
# 图片通道数
CHANNELS : 3
# 预测模式,支持 NATIVE 和 ANALYSIS
PREDICTOR_MODE: "ANALYSIS"
# 每次预测的 batch_size
BATCH_SIZE : 3
```
模型文件就绪后,在终端输入以下命令进行预测:
```
python infer.py --conf=/path/to/deploy.yaml --input_dir/path/to/images_directory --use_pr=False
```
参数说明如下:
| 参数 | 是否必须|含义 |
|-------|-------|----------|
| conf | YES|模型配置的Yaml文件路径 |
| input_dir |YES| 需要预测的图片目录 |
| use_pr |NO|是否使用优化模型,默认为False|
* 优化模型:使用`PaddleSeg 0.3.0`版导出的为优化模型, 此前版本导出的模型即为未优化版本。优化模型把图像的预处理以及后处理部分融入到模型网络中使用`GPU` 完成,相比原来`CPU` 中的处理提升了计算性能。
运行后会扫描`input_dir` 目录下所有指定格式图片,生成`预测mask``可视化的结果`
对于图片`a.jpeg`, `预测mask` 存在`a_jpeg.png` 中,而可视化结果则在`a_jpeg_result.png` 中。
输入样例:
![avatar](../cpp/images/humanseg/demo2.jpeg)
输出结果:
![avatar](../cpp/images/humanseg/demo2.jpeg_result.png)
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import ast
import time
import gflags
import yaml
import cv2
import numpy as np
import paddle.fluid as fluid
from concurrent.futures import ThreadPoolExecutor, as_completed
gflags.DEFINE_string("conf", default="", help="Configuration File Path")
gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images")
gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model")
gflags.DEFINE_string("trt_mode", default="", help="Use optimized model")
gflags.FLAGS = gflags.FLAGS
# ColorMap for visualization
color_map = [[128, 64, 128], [244, 35, 231], [69, 69, 69], [102, 102, 156],
[190, 153, 153], [153, 153, 153], [250, 170, 29], [219, 219, 0],
[106, 142, 35], [152, 250, 152], [69, 129, 180], [219, 19, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 69], [0, 60, 100], [0, 79, 100],
[0, 0, 230], [119, 10, 32]]
# Paddle-TRT Precision Map
trt_precision_map = {
"int8": fluid.core.AnalysisConfig.Precision.Int8,
"fp32": fluid.core.AnalysisConfig.Precision.Float32,
"fp16": fluid.core.AnalysisConfig.Precision.Half
}
# scan a directory and get all images with support extensions
def get_images_from_dir(img_dir, support_ext=".jpg|.jpeg"):
if (not os.path.exists(img_dir) or not os.path.isdir(img_dir)):
raise Exception("Image Directory [%s] invalid" % img_dir)
imgs = []
for item in os.listdir(img_dir):
ext = os.path.splitext(item)[1][1:].strip().lower()
if (len(ext) > 0 and ext in support_ext):
item_path = os.path.join(img_dir, item)
imgs.append(item_path)
return imgs
# Deploy Configuration File Parser
class DeployConfig:
def __init__(self, conf_file):
if not os.path.exists(conf_file):
raise Exception('Config file path [%s] invalid!' % conf_file)
with open(conf_file) as fp:
configs = yaml.load(fp, Loader=yaml.FullLoader)
deploy_conf = configs["DEPLOY"]
# 1. get eval_crop_size
self.eval_crop_size = ast.literal_eval(deploy_conf["EVAL_CROP_SIZE"])
# 2. get mean
self.mean = deploy_conf["MEAN"]
# 3. get std
self.std = deploy_conf["STD"]
# 4. get class_num
self.class_num = deploy_conf["NUM_CLASSES"]
# 5. get paddle model and params file path
self.model_file = os.path.join(
deploy_conf["MODEL_PATH"], deploy_conf["MODEL_FILENAME"])
self.param_file = os.path.join(
deploy_conf["MODEL_PATH"], deploy_conf["PARAMS_FILENAME"])
# 6. use_gpu
self.use_gpu = deploy_conf["USE_GPU"]
# 7. predictor_mode
self.predictor_mode = deploy_conf["PREDICTOR_MODE"]
# 8. batch_size
self.batch_size = deploy_conf["BATCH_SIZE"]
# 9. channels
self.channels = deploy_conf["CHANNELS"]
class ImageReader:
def __init__(self, configs):
self.config = configs
self.threads_pool = ThreadPoolExecutor(configs.batch_size)
# image processing thread worker
def process_worker(self, imgs, idx, use_pr=False):
image_path = imgs[idx]
im = cv2.imread(image_path, -1)
channels = im.shape[2]
ori_h = im.shape[0]
ori_w = im.shape[1]
if channels == 1:
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
channels = im.shape[2]
if channels != 3 and channels != 4:
print("Only support rgb(gray) or rgba image.")
return -1
# resize to eval_crop_size
eval_crop_size = self.config.eval_crop_size
if (ori_h != eval_crop_size[0] or ori_w != eval_crop_size[1]):
im = cv2.resize(
im, eval_crop_size, fx=0, fy=0, interpolation=cv2.INTER_LINEAR)
# if use models with no pre-processing/post-processing op optimizations
if not use_pr:
im_mean = np.array(self.config.mean).reshape((3, 1, 1))
im_std = np.array(self.config.std).reshape((3, 1, 1))
# HWC -> CHW, don't use transpose((2, 0, 1))
im = im.swapaxes(1, 2)
im = im.swapaxes(0, 1)
im = im[:, :, :].astype('float32') / 255.0
im -= im_mean
im /= im_std
im = im[np.newaxis,:,:,:]
info = [image_path, im, (ori_w, ori_h)]
return info
# process multiple images with multithreading
def process(self, imgs, use_pr=False):
imgs_data = []
with ThreadPoolExecutor(max_workers=self.config.batch_size) as exec:
tasks = [exec.submit(self.process_worker, imgs, idx, use_pr)
for idx in range(len(imgs))]
for task in as_completed(tasks):
imgs_data.append(task.result())
return imgs_data
class Predictor:
def __init__(self, conf_file):
self.config = DeployConfig(conf_file)
self.image_reader = ImageReader(self.config)
if self.config.predictor_mode == "NATIVE":
predictor_config = fluid.core.NativeConfig()
predictor_config.prog_file = self.config.model_file
predictor_config.param_file = self.config.param_file
predictor_config.use_gpu = config.use_gpu
predictor_config.device = 0
predictor_config.fraction_of_gpu_memory = 0
elif self.config.predictor_mode == "ANALYSIS":
predictor_config = fluid.core.AnalysisConfig(
self.config.model_file, self.config.param_file)
if self.config.use_gpu:
predictor_config.enable_use_gpu(100, 0)
predictor_config.switch_ir_optim(True)
if gflags.FLAGS.trt_mode != "":
precision_type = trt_precision_map[gflags.FLAGS.trt_mode]
use_calib = (gflags.FLAGS.trt_mode == "int8")
predictor_config.enable_tensorrt_engine(
workspace_size=1<<30,
max_batch_size=self.config.batch_size,
min_subgraph_size=40,
precision_mode=precision_type,
use_static=False,
use_calib_mode=use_calib)
else:
predictor_config.disable_gpu()
predictor_config.switch_specify_input_names(True)
predictor_config.enable_memory_optim()
self.predictor = fluid.core.create_paddle_predictor(predictor_config)
def create_tensor(self, inputs, batch_size, use_pr=False):
im_tensor = fluid.core.PaddleTensor()
im_tensor.name = "image"
if not use_pr:
im_tensor.shape = [batch_size,
self.config.channels,
self.config.eval_crop_size[1],
self.config.eval_crop_size[0]]
else:
im_tensor.shape = [batch_size,
self.config.eval_crop_size[1],
self.config.eval_crop_size[0],
self.config.channels]
im_tensor.dtype = fluid.core.PaddleDType.FLOAT32
im_tensor.data = fluid.core.PaddleBuf(inputs.ravel().astype("float32"))
return [im_tensor]
# save prediction results and visualization them
def output_result(self, imgs_data, infer_out, use_pr=False):
for idx in range(len(imgs_data)):
img_name = imgs_data[idx][0]
ori_shape = imgs_data[idx][2]
mask = infer_out[idx]
if not use_pr:
mask = np.argmax(mask, axis=0)
mask = mask.astype('uint8')
mask_png = mask
score_png = mask_png[:, :, np.newaxis]
score_png = np.concatenate([score_png] * 3, axis=2)
# visualization score png
for i in range(score_png.shape[0]):
for j in range(score_png.shape[1]):
score_png[i, j] = color_map[score_png[i, j, 0]]
# save the mask
# mask of xxx.jpeg will be saved as xxx_jpeg_mask.png
ext_pos = img_name.rfind(".")
img_name_fix = img_name[:ext_pos] + "_" + img_name[ext_pos + 1:]
mask_save_name = img_name_fix + "_mask.png"
cv2.imwrite(mask_save_name, mask_png, [cv2.CV_8UC1])
# save the visualized result
# result of xxx.jpeg will be saved as xxx_jpeg_result.png
vis_result_name = img_name_fix + "_result.png"
result_png = score_png
# if not use_pr:
result_png = cv2.resize(result_png, ori_shape, fx=0, fy=0,
interpolation=cv2.INTER_CUBIC)
cv2.imwrite(vis_result_name, result_png, [cv2.CV_8UC1])
print("save result of [" + img_name + "] done.")
def predict(self, images):
# image reader preprocessing time cost
reader_time = 0
# inference time cost
infer_time = 0
# post_processing: generate mask and visualize it
post_time = 0
# total time cost: preprocessing + inference + postprocessing
total_runtime = 0
# record starting time point
total_start = time.time()
batch_size = self.config.batch_size
for i in range(0, len(images), batch_size):
real_batch_size = batch_size
if i + batch_size >= len(images):
real_batch_size = len(images) - i
reader_start = time.time()
img_datas = self.image_reader.process(images[i: i + real_batch_size])
input_data = np.concatenate([item[1] for item in img_datas])
input_data = self.create_tensor(
input_data, real_batch_size, use_pr=gflags.FLAGS.use_pr)
reader_end = time.time()
infer_start = time.time()
output_data = self.predictor.run(input_data)[0]
infer_end = time.time()
reader_time += (reader_end - reader_start)
infer_time += (infer_end - infer_start)
output_data = output_data.as_ndarray()
post_start = time.time()
self.output_result(img_datas, output_data, gflags.FLAGS.use_pr)
post_end = time.time()
post_time += (post_end - post_start)
# finishing process all images
total_end = time.time()
# compute whole processing time
total_runtime = (total_end - total_start)
print("images_num=[%d],preprocessing_time=[%f],infer_time=[%f],postprocessing_time=[%f],total_runtime=[%f]"
% (len(images), reader_time, infer_time, post_time, total_runtime))
def run(deploy_conf, imgs_dir, support_extensions=".jpg|.jpeg"):
# 1. scan and get all images with valid extensions in directory imgs_dir
imgs = get_images_from_dir(imgs_dir)
if len(imgs) == 0:
print("No Image (with extensions : %s) found in [%s]"
% (support_extensions, imgs_dir))
return -1
# 2. create a predictor
seg_predictor = Predictor(deploy_conf)
# 3. do a inference on images
seg_predictor.predict(imgs)
return 0
if __name__ == "__main__":
# 0. parse the arguments
gflags.FLAGS(sys.argv)
if (gflags.FLAGS.conf == "" or gflags.FLAGS.input_dir == ""):
print("Usage: python infer.py --conf=/config/path/to/your/model "
+"--input_dir=/directory/of/your/input/images [--use_pr=True]")
exit(-1)
# set empty to turn off as default
trt_mode = gflags.FLAGS.trt_mode
if (trt_mode != "" and trt_mode not in trt_precision_map):
print("Invalid trt_mode [%s], only support[int8, fp16, fp32]" % trt_mode)
exit(-1)
# run inference
run(gflags.FLAGS.conf, gflags.FLAGS.input_dir)
......@@ -18,4 +18,4 @@
python pdseg/export_model.py --cfg configs/unet_pet.yaml TEST.TEST_MODEL test/saved_models/unet_pet/final
```
预测模型会导出到`freeze_model`目录,用于C++预测的模型配置会导出到`freeze_model/deploy.yaml`
预测模型会导出到`freeze_model`目录,用于`C++`或者`Python`预测的模型配置会导出到`freeze_model/deploy.yaml`
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import gflags
import yaml
import numpy as np
import cv2
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import NativeConfig
from paddle.fluid.core import create_paddle_predictor
from paddle.fluid.core import PaddleTensor
from paddle.fluid.core import PaddleBuf
from paddle.fluid.core import PaddleDType
from concurrent.futures import ThreadPoolExecutor, as_completed
gflags.DEFINE_string("conf", default="", help="Configuration File Path")
gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images")
gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model")
Flags = gflags.FLAGS
# ColorMap for visualization more clearly
color_map = [[128, 64, 128], [244, 35, 231], [69, 69, 69], [102, 102, 156],
[190, 153, 153], [153, 153, 153], [250, 170, 29], [219, 219, 0],
[106, 142, 35], [152, 250, 152], [69, 129, 180], [219, 19, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 69], [0, 60, 100], [0, 79, 100],
[0, 0, 230], [119, 10, 32]]
class ConfItemsNotFoundError(Exception):
def __init__(self, message):
super().__init__(message + " item not Found")
class Config:
def __init__(self, config_dict):
if "DEPLOY" not in config_dict:
raise ConfItemsNotFoundError("DEPLOY")
deploy_dict = config_dict["DEPLOY"]
if "EVAL_CROP_SIZE" not in deploy_dict:
raise ConfItemsNotFoundError("EVAL_CROP_SIZE")
# 1. get resize
self.resize = [int(value) for value in
deploy_dict["EVAL_CROP_SIZE"].strip("()").split(",")]
# 2. get mean
if "MEAN" not in deploy_dict:
raise ConfItemsNotFoundError("MEAN")
self.mean = deploy_dict["MEAN"]
# 3. get std
if "STD" not in deploy_dict:
raise ConfItemsNotFoundError("STD")
self.std = deploy_dict["STD"]
# 4. get image type
if "IMAGE_TYPE" not in deploy_dict:
raise ConfItemsNotFoundError("IMAGE_TYPE")
self.img_type = deploy_dict["IMAGE_TYPE"]
# 5. get class number
if "NUM_CLASSES" not in deploy_dict:
raise ConfItemsNotFoundError("NUM_CLASSES")
self.class_num = deploy_dict["NUM_CLASSES"]
# 7. set model path
if "MODEL_PATH" not in deploy_dict:
raise ConfItemsNotFoundError("MODEL_PATH")
self.model_path = deploy_dict["MODEL_PATH"]
# 8. get model file_name
if "MODEL_FILENAME" not in deploy_dict:
self.model_file_name = "__model__"
else:
self.model_file_name = deploy_dict["MODEL_FILENAME"]
# 9. get model param file name
if "PARAMS_FILENAME" not in deploy_dict:
self.param_file_name = "__params__"
else:
self.param_file_name = deploy_dict["PARAMS_FILENAME"]
# 10. get pre_processor
if "PRE_PROCESSOR" not in deploy_dict:
raise ConfItemsNotFoundError("PRE_PROCESSOR")
self.pre_processor = deploy_dict["PRE_PROCESSOR"]
# 11. use_gpu
if "USE_GPU" not in deploy_dict:
self.use_gpu = 0
else:
self.use_gpu = deploy_dict["USE_GPU"]
# 12. predictor_mode
if "PREDICTOR_MODE" not in deploy_dict:
raise ConfItemsNotFoundError("PREDICTOR_MODE")
self.predictor_mode = deploy_dict["PREDICTOR_MODE"]
# 13. batch_size
if "BATCH_SIZE" not in deploy_dict:
raise ConfItemsNotFoundError("BATCH_SIZE")
self.batch_size = deploy_dict["BATCH_SIZE"]
# 14. channels
if "CHANNELS" not in deploy_dict:
raise ConfItemsNotFoundError("CHANNELS")
self.channels = deploy_dict["CHANNELS"]
class PreProcessor:
def __init__(self, config):
self.resize_size = (config.resize[0], config.resize[1])
self.mean = config.mean
self.std = config.std
def process(self, image_file, im_list, ori_h_list, ori_w_list, idx, use_pr=False):
start = time.time()
im = cv2.imread(image_file, -1)
end = time.time()
print("imread spent %fs" % (end - start))
channels = im.shape[2]
ori_h = im.shape[0]
ori_w = im.shape[1]
if channels == 1:
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
channels = im.shape[2]
if channels != 3 and channels != 4:
print("Only support rgb(gray) or rgba image.")
return -1
if ori_h != self.resize_size[0] or ori_w != self.resize_size[1]:
start = time.time()
im = cv2.resize(im, self.resize_size, fx=0, fy=0, interpolation=cv2.INTER_LINEAR)
end = time.time()
print("resize spent %fs" % (end - start))
if not use_pr:
start = time.time()
im_mean = np.array(self.mean).reshape((3, 1, 1))
im_std = np.array(self.std).reshape((3, 1, 1))
# HWC -> CHW, don't use transpose((2, 0, 1))
im = im.swapaxes(1, 2)
im = im.swapaxes(0, 1)
im = im[:, :, :].astype('float32') / 255.0
im -= im_mean
im /= im_std
end = time.time()
print("preprocessing spent %fs" % (end-start))
im = im[np.newaxis,:,:,:]
im_list[idx] = im
ori_h_list[idx] = ori_h
ori_w_list[idx] = ori_w
return im, ori_h, ori_w
class Predictor:
def __init__(self, config):
self.config = config
model_file = os.path.join(config.model_path, config.model_file_name)
param_file = os.path.join(config.model_path, config.param_file_name)
if config.predictor_mode == "NATIVE":
predictor_config = NativeConfig()
predictor_config.prog_file = model_file
predictor_config.param_file = param_file
predictor_config.use_gpu = config.use_gpu
predictor_config.device = 0
predictor_config.fraction_of_gpu_memory = 0
elif config.predictor_mode == "ANALYSIS":
predictor_config = AnalysisConfig(model_file, param_file)
if config.use_gpu:
predictor_config.enable_use_gpu(100, 0)
else:
predictor_config.disable_gpu()
# need to use zero copy run
# predictor_config.switch_use_feed_fetch_ops(False)
# predictor_config.enable_tensorrt_engine(
# workspace_size=1<<30,
# max_batch_size=1,
# min_subgraph_size=3,
# precision_mode=AnalysisConfig.Precision.Int8,
# use_static=False,
# use_calib_mode=True
# )
predictor_config.switch_specify_input_names(True)
predictor_config.enable_memory_optim()
self.predictor = create_paddle_predictor(predictor_config)
self.preprocessor = PreProcessor(config)
self.threads_pool = ThreadPoolExecutor(config.batch_size)
def make_tensor(self, inputs, batch_size, use_pr=False):
im_tensor = PaddleTensor()
im_tensor.name = "image"
if not use_pr:
im_tensor.shape = [batch_size, self.config.channels,
self.config.resize[1], self.config.resize[0]]
else:
im_tensor.shape = [batch_size, self.config.resize[1],
self.config.resize[0], self.config.channels]
print(im_tensor.shape)
im_tensor.dtype = PaddleDType.FLOAT32
start = time.time()
im_tensor.data = PaddleBuf(inputs.ravel().astype("float32"))
print("flatten time: %f" % (time.time() - start))
return [im_tensor]
def output_result(self, image_name, output, ori_h, ori_w, use_pr=False):
mask = output
if not use_pr:
mask = np.argmax(output, axis=0)
mask = mask.astype('uint8')
mask_png = mask
score_png = mask_png[:, :, np.newaxis]
score_png = np.concatenate([score_png] * 3, axis=2)
for i in range(score_png.shape[0]):
for j in range(score_png.shape[1]):
score_png[i, j] = color_map[score_png[i, j, 0]]
mask_save_name = image_name + ".png"
cv2.imwrite(mask_save_name, mask_png, [cv2.CV_8UC1])
result_name = image_name + "_result.png"
result_png = score_png
# if not use_pr:
result_png = cv2.resize(result_png, (ori_w, ori_h), fx=0, fy=0,
interpolation=cv2.INTER_CUBIC)
cv2.imwrite(result_name, result_png, [cv2.CV_8UC1])
print("save result of [" + image_name + "] done.")
def predict(self, images):
batch_size = self.config.batch_size
total_runtime = 0
total_imwrite_time = 0
for i in range(0, len(images), batch_size):
start = time.time()
bz = batch_size
if i + batch_size >= len(images):
bz = len(images) - i
im_list = [0] * bz
ori_h_list = [0] * bz
ori_w_list = [0] * bz
tasks = [self.threads_pool.submit(self.preprocessor.process,
images[i + j], im_list,
ori_h_list, ori_w_list,
j, Flags.use_pr)
for j in range(bz)]
# join all running threads
for t in as_completed(tasks):
pass
input_data = np.concatenate(im_list)
input_data = self.make_tensor(input_data, bz, use_pr=Flags.use_pr)
inference_start = time.time()
output_data = self.predictor.run(input_data)[0]
end = time.time()
print("inference time = %fs " % (end - inference_start))
print("runtime = %fs " % (end - start))
total_runtime += (end - start)
output_data = output_data.as_ndarray()
output_start = time.time()
for j in range(bz):
self.output_result(images[i + j], output_data[j],
ori_h_list[j], ori_w_list[j], Flags.use_pr)
output_end = time.time()
total_imwrite_time += output_end - output_start
print("total time = %fs" % total_runtime)
print("total imwrite time = %fs" % total_imwrite_time)
def usage():
print("Usage: python infer.py --conf=/config/path/to/your/model " +
"--input_dir=/directory/of/your/input/images [--use_pr=True]")
def read_conf(conf_file):
if not os.path.exists(conf_file):
raise FileNotFoundError("Can't find the configuration file path," +
" please check whether the configuration" +
" path is correctly set.")
f = open(conf_file)
config_dict = yaml.load(f, Loader=yaml.FullLoader)
config = Config(config_dict)
return config
def read_input_dir(input_dir, ext=".jpg|.jpeg"):
if not os.path.exists(input_dir):
raise FileNotFoundError("This input directory doesn't exist, please" +
" check whether the input directory is" +
" correctly set.")
if not os.path.isdir(input_dir):
raise NotADirectoryError("This input directory in not a directory," +
" please check whether the input directory" +
" is correctly set.")
files_list = []
ext_list = ext.split("|")
files = os.listdir(input_dir)
for file in files:
for ext_suffix in ext_list:
if file.endswith(ext_suffix):
full_path = os.path.join(input_dir, file)
files_list.append(full_path)
break
return files_list
def main(argv):
# 0. parse the argument
Flags(argv)
if Flags.conf == "" or Flags.input_dir == "":
usage()
return -1
try:
# 1. get a conf dictionary
seg_deploy_configs = read_conf(Flags.conf)
# 2. get all the images path with extension '.jpeg' at input_dir
images = read_input_dir(Flags.input_dir)
if len(images) == 0:
print("No Images Found! Please check whether the images format" +
" is correct. Supporting format: [.jpeg|.jpg].")
print(images)
except Exception as e:
print(e)
return -1
# 3. init predictor and predict
seg_predictor = Predictor(seg_deploy_configs)
seg_predictor.predict(images)
if __name__ == "__main__":
main(sys.argv)
# PaddleSeg Python 预测部署方案
本说明文档旨在提供一个跨平台的图像分割模型的Python预测部署方案,用户通过一定的配置,加上少量的代码,即可把模型集成到自己的服务中,完成图像分割的任务。
## 前置条件
* Python2.7+,Python3+
* pip,pip3
## 主要目录和文件
```
inference
├── infer.py # 完成预测、可视化的Python脚本
└── requirements.txt # 预测部署脚本所依赖的库
```
### Step1:安装PaddlePaddle
可参考以下链接,选择合适版本的PaddlePaddle进行安装。[PaddlePaddle安装教程](https://www.paddlepaddle.org.cn/install/doc/)
### Step2:安装Python依赖包
在inference目录下,安装相应的Python预测依赖包
```bash
pip install -r requirements.txt
```
因为预测部署中需要使用opencv,所以还需要安装相关的动态链接库。相关操作如下:
Ubuntu 下安装相关链接库:
```bash
apt-get install -y libglib2.0-0 libsm6 libxext6 libxrender-dev
```
CentOS 下安装相关链接库:
```bash
yum install -y libXext libSM libXrender
```
### Step3:预测
在终端输入以下命令进行预测。
```
python infer.py --conf=/path/to/XXX.yaml --input_dir/path/to/images_directory --use_pr=False
```
预测使用的三个命令参数说明如下:
| 参数 | 含义 |
|-------|----------|
| conf | 模型配置的Yaml文件路径 |
| input_dir | 需要预测的图片目录 |
| use_pr | 是否使用优化模型,默认为False|
* 优化模型:对于图像分割模型,由于模型输入的数据需要使用CPU对读取的图像数据进行预处理,预处理时长较长,为了降低在使用GPU进行端到端预测时的延时,优化模型把预处理部分融入到模型当中。在使用GPU进行预测时,优化模型的预处理部分将会在GPU上进行,大大降低了端到端延时。可使用新版的模型导出工具导出优化模型。
![avatar](images/humanseg/demo2.jpeg)
输出预测结果
![avatar](images/humanseg/demo2.jpeg_result.png)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册