提交 3e19600f 编写于 作者: S sjtubinlong 提交者: sjtubinlong

add python inference

上级 dba781af
# PaddleSeg C++预测部署方案 # PaddleSeg 预测部署方案
[1.说明](#1说明) [1.说明](#1说明)
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
## 1.说明 ## 1.说明
本目录提供一个跨平台的图像分割模型的C++预测部署方案,用户通过一定的配置,加上少量的代码,即可把模型集成到自己的服务中,完成图像分割的任务。 本目录提供一个跨平台的图像分割模型的C++、Python预测部署方案,用户通过一定的配置,加上少量的代码,即可把模型集成到自己的服务中,完成图像分割的任务。
主要设计的目标包括以下四点: 主要设计的目标包括以下四点:
- 跨平台,支持在 windows 和 Linux 完成编译、开发和部署 - 跨平台,支持在 windows 和 Linux 完成编译、开发和部署
...@@ -57,6 +57,8 @@ inference ...@@ -57,6 +57,8 @@ inference
`Windows`上推荐使用最新的`Visual Studio 2019 Community`直接编译`CMake`项目。 `Windows`上推荐使用最新的`Visual Studio 2019 Community`直接编译`CMake`项目。
针对Python的预测部署方法,可参考以下链接:[Python预测部署方法](python_inference.md)
## 4.预测并可视化结果 ## 4.预测并可视化结果
完成编译后,便生成了需要的可执行文件和链接库,然后执行以下步骤: 完成编译后,便生成了需要的可执行文件和链接库,然后执行以下步骤:
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import gflags
import yaml
import numpy as np
import cv2
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import NativeConfig
from paddle.fluid.core import create_paddle_predictor
from paddle.fluid.core import PaddleTensor
from paddle.fluid.core import PaddleBuf
from paddle.fluid.core import PaddleDType
from concurrent.futures import ThreadPoolExecutor, as_completed
# Command-line flags of the script: the deployment YAML path, the directory
# holding the input images, and the optimized-model switch.
gflags.DEFINE_string(
    "conf",
    default="",
    help="Configuration File Path")
gflags.DEFINE_string(
    "input_dir",
    default="",
    help="Directory of Input Images")
gflags.DEFINE_boolean(
    "use_pr",
    default=False,
    help="Use optimized model")

# Shared flag registry; it is called with argv to parse the command line.
Flags = gflags.FLAGS
# Palette used to make the visualized prediction easier to read:
# entry i is the 3-value color written for class id i (19 entries).
color_map = [
    [128, 64, 128],
    [244, 35, 231],
    [69, 69, 69],
    [102, 102, 156],
    [190, 153, 153],
    [153, 153, 153],
    [250, 170, 29],
    [219, 219, 0],
    [106, 142, 35],
    [152, 250, 152],
    [69, 129, 180],
    [219, 19, 60],
    [255, 0, 0],
    [0, 0, 142],
    [0, 0, 69],
    [0, 60, 100],
    [0, 79, 100],
    [0, 0, 230],
    [119, 10, 32],
]
class ConfItemsNotFoundError(Exception):
    """Raised when a required item is absent from the deployment config.

    The exception text is the item name followed by " item not Found".
    """

    def __init__(self, message):
        text = message + " item not Found"
        super().__init__(text)
class Config:
    """Deployment settings read from the "DEPLOY" section of a config dict.

    Required items raise ConfItemsNotFoundError when absent; optional items
    fall back to a default. Items are validated in a fixed order, so the
    first missing required item is the one reported.

    Attributes:
        resize: list of ints parsed from EVAL_CROP_SIZE, a string such as
            "(512, 512)".
        mean, std: values of MEAN and STD.
        img_type: value of IMAGE_TYPE.
        class_num: value of NUM_CLASSES.
        model_path: value of MODEL_PATH.
        model_file_name: MODEL_FILENAME, default "__model__".
        param_file_name: PARAMS_FILENAME, default "__params__".
        pre_processor: value of PRE_PROCESSOR.
        use_gpu: USE_GPU, default 0.
        predictor_mode: value of PREDICTOR_MODE.
        batch_size: value of BATCH_SIZE.
        channels: value of CHANNELS.
    """

    def __init__(self, config_dict):
        if "DEPLOY" not in config_dict:
            raise ConfItemsNotFoundError("DEPLOY")
        section = config_dict["DEPLOY"]

        def required(key):
            # Mandatory item: report its name when it is missing.
            if key in section:
                return section[key]
            raise ConfItemsNotFoundError(key)

        def optional(key, fallback):
            # Optional item: use the fallback when it is missing.
            return section[key] if key in section else fallback

        # 1. crop size is stored as text, e.g. "(w, h)"; turn it into ints
        crop_text = required("EVAL_CROP_SIZE")
        self.resize = [
            int(token) for token in crop_text.strip("()").split(",")
        ]
        # 2-3. normalization constants
        self.mean = required("MEAN")
        self.std = required("STD")
        # 4. image type
        self.img_type = required("IMAGE_TYPE")
        # 5. number of classes
        self.class_num = required("NUM_CLASSES")
        # 7-9. model directory and file names
        self.model_path = required("MODEL_PATH")
        self.model_file_name = optional("MODEL_FILENAME", "__model__")
        self.param_file_name = optional("PARAMS_FILENAME", "__params__")
        # 10. pre-processor name
        self.pre_processor = required("PRE_PROCESSOR")
        # 11. GPU switch
        self.use_gpu = optional("USE_GPU", 0)
        # 12. predictor mode
        self.predictor_mode = required("PREDICTOR_MODE")
        # 13. batch size
        self.batch_size = required("BATCH_SIZE")
        # 14. channel count
        self.channels = required("CHANNELS")
class PreProcessor:
    """Read an image from disk and turn it into one batch entry.

    Attributes:
        resize_size: (config.resize[0], config.resize[1]); handed to
            cv2.resize as the target size, i.e. (width, height).
        mean, std: per-channel normalization constants from the config.
    """

    def __init__(self, config):
        self.resize_size = (config.resize[0], config.resize[1])
        self.mean = config.mean
        self.std = config.std

    def process(self, image_file, im_list, ori_h_list, ori_w_list, idx,
                use_pr=False):
        """Load, resize and (unless use_pr) normalize a single image.

        The result is also stored at position ``idx`` of the three output
        lists so that several calls can fill one batch concurrently.

        Args:
            image_file: path of the image to read.
            im_list: output list receiving the processed image.
            ori_h_list: output list receiving the original height.
            ori_w_list: output list receiving the original width.
            idx: slot of the output lists written by this call.
            use_pr: when True the image is only resized and kept in
                height-width-channel order; normalization is skipped.

        Returns:
            (im, ori_h, ori_w) on success, where ``im`` has a leading batch
            axis of length 1; -1 when the image cannot be read or has an
            unsupported channel count (the output lists are left untouched).
        """
        start = time.time()
        im = cv2.imread(image_file, -1)
        end = time.time()
        print("imread spent %fs" % (end - start))
        if im is None:
            # cv2.imread reports an unreadable/undecodable file by returning
            # None instead of raising.
            print("Failed to read image: %s" % image_file)
            return -1
        ori_h, ori_w = im.shape[:2]
        # A single-channel image read with flag -1 comes back as a 2-D array,
        # so the channel axis may not exist at all.
        if im.ndim == 2 or im.shape[2] == 1:
            im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
        channels = im.shape[2]
        if channels != 3 and channels != 4:
            print("Only support rgb(gray) or rgba image.")
            return -1
        # resize_size is (width, height): compare width with width and
        # height with height before deciding whether a resize is needed.
        target_w, target_h = self.resize_size
        if ori_w != target_w or ori_h != target_h:
            start = time.time()
            im = cv2.resize(im, self.resize_size,
                            interpolation=cv2.INTER_LINEAR)
            end = time.time()
            print("resize spent %fs" % (end - start))
        if not use_pr:
            start = time.time()
            # -1 lets mean/std carry one value per image channel.
            # NOTE(review): a 4-channel image still needs 4-element MEAN/STD,
            # otherwise the subtraction below cannot broadcast.
            im_mean = np.array(self.mean).reshape((-1, 1, 1))
            im_std = np.array(self.std).reshape((-1, 1, 1))
            # HWC -> CHW, don't use transpose((2, 0, 1))
            im = im.swapaxes(1, 2)
            im = im.swapaxes(0, 1)
            im = im.astype('float32') / 255.0
            im -= im_mean
            im /= im_std
            end = time.time()
            print("preprocessing spent %fs" % (end - start))
        # Prepend the batch axis so that batch entries can be concatenated.
        im = im[np.newaxis, :, :, :]
        im_list[idx] = im
        ori_h_list[idx] = ori_h
        ori_w_list[idx] = ori_w
        return im, ori_h, ori_w
class Predictor:
    """Batch-predict segmentation masks and write them next to the inputs.

    Attributes:
        config: the deployment configuration object.
        predictor: predictor created by create_paddle_predictor.
        preprocessor: PreProcessor built from the same config.
        threads_pool: thread pool with config.batch_size workers, used to
            preprocess the images of one batch concurrently.
    """

    def __init__(self, config):
        """Create the predictor selected by config.predictor_mode.

        Raises:
            ValueError: if predictor_mode is neither "NATIVE" nor "ANALYSIS".
        """
        self.config = config
        model_file = os.path.join(config.model_path, config.model_file_name)
        param_file = os.path.join(config.model_path, config.param_file_name)
        if config.predictor_mode == "NATIVE":
            predictor_config = NativeConfig()
            predictor_config.prog_file = model_file
            predictor_config.param_file = param_file
            predictor_config.use_gpu = config.use_gpu
            predictor_config.device = 0
            predictor_config.fraction_of_gpu_memory = 0
        elif config.predictor_mode == "ANALYSIS":
            predictor_config = AnalysisConfig(model_file, param_file)
            if config.use_gpu:
                predictor_config.enable_use_gpu(100, 0)
            else:
                predictor_config.disable_gpu()
            # need to use zero copy run
            # predictor_config.switch_use_feed_fetch_ops(False)
            # predictor_config.enable_tensorrt_engine(
            #     workspace_size=1<<30,
            #     max_batch_size=1,
            #     min_subgraph_size=3,
            #     precision_mode=AnalysisConfig.Precision.Int8,
            #     use_static=False,
            #     use_calib_mode=True
            # )
            predictor_config.switch_specify_input_names(True)
            predictor_config.enable_memory_optim()
        else:
            # Without this branch an unknown mode would only surface later
            # as an unbound local variable.
            raise ValueError(
                "Unsupported PREDICTOR_MODE: %r, expected \"NATIVE\" or "
                "\"ANALYSIS\"" % (config.predictor_mode,))
        self.predictor = create_paddle_predictor(predictor_config)
        self.preprocessor = PreProcessor(config)
        self.threads_pool = ThreadPoolExecutor(config.batch_size)

    def make_tensor(self, inputs, batch_size, use_pr=False):
        """Wrap a batch array into the single input tensor named "image".

        Args:
            inputs: array holding the whole batch; it is flattened and
                converted to float32.
            batch_size: number of images in ``inputs``.
            use_pr: selects the declared shape: channels-first
                [N, C, resize[1], resize[0]] when False, channels-last
                [N, resize[1], resize[0], C] when True.

        Returns:
            A one-element list containing the PaddleTensor.
        """
        im_tensor = PaddleTensor()
        im_tensor.name = "image"
        if not use_pr:
            im_tensor.shape = [batch_size, self.config.channels,
                               self.config.resize[1], self.config.resize[0]]
        else:
            im_tensor.shape = [batch_size, self.config.resize[1],
                               self.config.resize[0], self.config.channels]
        print(im_tensor.shape)
        im_tensor.dtype = PaddleDType.FLOAT32
        start = time.time()
        im_tensor.data = PaddleBuf(inputs.ravel().astype("float32"))
        print("flatten time: %f" % (time.time() - start))
        return [im_tensor]

    def output_result(self, image_name, output, ori_h, ori_w, use_pr=False):
        """Save the class-id mask and its colorized visualization.

        Writes ``<image_name>.png`` (class ids as uint8) and
        ``<image_name>_result.png`` (ids mapped through color_map and resized
        to the original image size).

        Args:
            image_name: path of the source image, used as output prefix.
            output: network output for one image. When use_pr is False the
                class id is taken as the argmax over axis 0; otherwise
                ``output`` is used as the mask directly.
            ori_h: original image height.
            ori_w: original image width.
            use_pr: whether ``output`` already holds class ids.
        """
        mask = output
        if not use_pr:
            mask = np.argmax(output, axis=0)
        mask = mask.astype('uint8')
        # One vectorized palette lookup replaces the per-pixel Python loop;
        # ids outside color_map still raise IndexError as before.
        palette = np.array(color_map, dtype='uint8')
        score_png = palette[mask]
        mask_save_name = image_name + ".png"
        # imwrite's optional third argument is a list of encoder key/value
        # pairs; the matrix type constant CV_8UC1 does not belong there.
        cv2.imwrite(mask_save_name, mask)
        result_name = image_name + "_result.png"
        result_png = cv2.resize(score_png, (ori_w, ori_h),
                                interpolation=cv2.INTER_CUBIC)
        cv2.imwrite(result_name, result_png)
        print("save result of [" + image_name + "] done.")

    def predict(self, images):
        """Run preprocessing, inference and result writing batch by batch.

        Args:
            images: list of image paths.

        Raises:
            RuntimeError: if an image of the current batch could not be
                preprocessed. Exceptions raised inside a preprocessing
                worker are re-raised as they are.
        """
        batch_size = self.config.batch_size
        total_runtime = 0
        total_imwrite_time = 0
        for i in range(0, len(images), batch_size):
            start = time.time()
            # The last batch may be smaller than batch_size.
            bz = min(batch_size, len(images) - i)
            im_list = [0] * bz
            ori_h_list = [0] * bz
            ori_w_list = [0] * bz
            tasks = {
                self.threads_pool.submit(self.preprocessor.process,
                                         images[i + j], im_list,
                                         ori_h_list, ori_w_list,
                                         j, Flags.use_pr): images[i + j]
                for j in range(bz)
            }
            # Join all workers. result() re-raises worker exceptions, and a
            # non-tuple result is the preprocessor's failure code; either
            # would otherwise leave a placeholder in im_list and fail later
            # with an unrelated concatenate error.
            for task in as_completed(tasks):
                if not isinstance(task.result(), tuple):
                    raise RuntimeError(
                        "Failed to preprocess image: %s" % tasks[task])
            input_data = np.concatenate(im_list)
            input_data = self.make_tensor(input_data, bz,
                                          use_pr=Flags.use_pr)
            inference_start = time.time()
            output_data = self.predictor.run(input_data)[0]
            end = time.time()
            print("inference time = %fs " % (end - inference_start))
            print("runtime = %fs " % (end - start))
            total_runtime += (end - start)
            output_data = output_data.as_ndarray()
            output_start = time.time()
            for j in range(bz):
                self.output_result(images[i + j], output_data[j],
                                   ori_h_list[j], ori_w_list[j],
                                   Flags.use_pr)
            output_end = time.time()
            total_imwrite_time += output_end - output_start
        print("total time = %fs" % total_runtime)
        print("total imwrite time = %fs" % total_imwrite_time)
def usage():
    """Print the command-line synopsis of the inference script."""
    synopsis = (
        "Usage: python infer.py --conf=/config/path/to/your/model "
        "--input_dir=/directory/of/your/input/images [--use_pr=True]"
    )
    print(synopsis)
def read_conf(conf_file):
    """Load the deployment YAML file and wrap it in a Config object.

    Args:
        conf_file: path of the YAML configuration file.

    Returns:
        The Config built from the parsed file.

    Raises:
        FileNotFoundError: if ``conf_file`` does not exist.
        ConfItemsNotFoundError: if a required item is missing (raised by
            Config), including the case of an empty file.
    """
    if not os.path.exists(conf_file):
        raise FileNotFoundError("Can't find the configuration file path," +
                                " please check whether the configuration" +
                                " path is correctly set.")
    # The context manager closes the file even when parsing fails.
    with open(conf_file) as f:
        # NOTE(review): yaml.load with FullLoader should only be fed trusted
        # configuration files; consider yaml.safe_load if that is not certain.
        config_dict = yaml.load(f, Loader=yaml.FullLoader)
    # An empty document parses to None; use an empty mapping so the missing
    # "DEPLOY" section is reported instead of a TypeError.
    if config_dict is None:
        config_dict = {}
    return Config(config_dict)
def read_input_dir(input_dir, ext=".jpg|.jpeg"):
    """Collect the image files of a directory by file-name suffix.

    Args:
        input_dir: directory to scan (not recursive).
        ext: accepted suffixes separated by "|". Matching ignores case, so
            ".jpg" also accepts "photo.JPG".

    Returns:
        Full paths of the matching entries, sorted by file name so that the
        processing order is the same on every run.

    Raises:
        FileNotFoundError: if ``input_dir`` does not exist.
        NotADirectoryError: if ``input_dir`` is not a directory.
    """
    if not os.path.exists(input_dir):
        raise FileNotFoundError("This input directory doesn't exist, please" +
                                " check whether the input directory is" +
                                " correctly set.")
    if not os.path.isdir(input_dir):
        raise NotADirectoryError("This input directory is not a directory," +
                                 " please check whether the input directory" +
                                 " is correctly set.")
    # str.endswith accepts a tuple, which replaces the inner suffix loop.
    suffixes = tuple(suffix.lower() for suffix in ext.split("|"))
    # os.listdir returns entries in arbitrary order; sort for determinism.
    return [os.path.join(input_dir, name)
            for name in sorted(os.listdir(input_dir))
            if name.lower().endswith(suffixes)]
def main(argv):
    """Parse the flags, load the config and images, then run prediction.

    Args:
        argv: full argument vector, program name included.

    Returns:
        -1 when a required flag is missing, when the config or the input
        directory cannot be read, or when no image is found; None after a
        completed prediction run.
    """
    # 0. parse the argument
    Flags(argv)
    if Flags.conf == "" or Flags.input_dir == "":
        usage()
        return -1
    try:
        # 1. get a conf dictionary
        seg_deploy_configs = read_conf(Flags.conf)
        # 2. get the paths of all '.jpg' / '.jpeg' images in input_dir
        images = read_input_dir(Flags.input_dir)
    except Exception as e:
        # Top-level boundary: report the problem and signal failure.
        print(e)
        return -1
    if len(images) == 0:
        print("No Images Found! Please check whether the images format" +
              " is correct. Supporting format: [.jpeg|.jpg].")
        # Nothing to predict, so do not load the model at all.
        return -1
    print(images)
    # 3. init predictor and predict
    seg_predictor = Predictor(seg_deploy_configs)
    seg_predictor.predict(images)
if __name__ == "__main__":
    # Hand main()'s status to the shell: -1 becomes a non-zero exit code,
    # None (successful run) becomes exit code 0.
    sys.exit(main(sys.argv))
# PaddleSeg Python 预测部署方案
本说明文档旨在提供一个跨平台的图像分割模型的Python预测部署方案,用户通过一定的配置,加上少量的代码,即可把模型集成到自己的服务中,完成图像分割的任务。
## 前置条件
* Python2.7+,Python3+
* pip,pip3
## 主要目录和文件
```
inference
├── infer.py # 完成预测、可视化的Python脚本
└── requirements.txt # 预测部署脚本所依赖的库
```
### Step1:安装PaddlePaddle
可参考以下链接,选择合适版本的PaddlePaddle进行安装。[PaddlePaddle安装教程](https://www.paddlepaddle.org.cn/install/doc/)
### Step2:安装Python依赖包
在inference目录下,安装相应的Python预测依赖包
```bash
pip install -r requirements.txt
```
因为预测部署中需要使用opencv,所以还需要安装相关的动态链接库。相关操作如下:
Ubuntu 下安装相关链接库:
```bash
apt-get install -y libglib2.0-0 libsm6 libxext6 libxrender-dev
```
CentOS 下安装相关链接库:
```bash
yum install -y libXext libSM libXrender
```
### Step3:预测
在终端输入以下命令进行预测。
```
python infer.py --conf=/path/to/XXX.yaml --input_dir=/path/to/images_directory --use_pr=False
```
预测使用的三个命令参数说明如下:
| 参数 | 含义 |
|-------|----------|
| conf | 模型配置的Yaml文件路径 |
| input_dir | 需要预测的图片目录 |
| use_pr | 是否使用优化模型,默认为False|
* 优化模型:对于图像分割模型,由于模型输入的数据需要使用CPU对读取的图像数据进行预处理,预处理时长较长,为了降低在使用GPU进行端到端预测时的延时,优化模型把预处理部分融入到模型当中。在使用GPU进行预测时,优化模型的预处理部分将会在GPU上进行,大大降低了端到端延时。可使用新版的模型导出工具导出优化模型。
![avatar](images/humanseg/demo2.jpeg)
输出预测结果
![avatar](images/humanseg/demo2.jpeg_result.png)
\ No newline at end of file
python-gflags
pyyaml
numpy
opencv-python
futures
\ No newline at end of file
...@@ -70,43 +70,108 @@ class PaddleSegModelConfigPaser { ...@@ -70,43 +70,108 @@ class PaddleSegModelConfigPaser {
bool load_config(const std::string& conf_file) { bool load_config(const std::string& conf_file) {
reset(); reset();
YAML::Node config;
YAML::Node config = YAML::LoadFile(conf_file); try {
config = YAML::LoadFile(conf_file);
} catch(...) {
return false;
}
// 1. get resize // 1. get resize
auto str = config["DEPLOY"]["EVAL_CROP_SIZE"].as<std::string>(); if (config["DEPLOY"]["EVAL_CROP_SIZE"].IsDefined()) {
_resize = parse_str_to_vec<int>(process_parenthesis(str)); auto str = config["DEPLOY"]["EVAL_CROP_SIZE"].as<std::string>();
_resize = parse_str_to_vec<int>(process_parenthesis(str));
} else {
std::cerr << "Please set EVAL_CROP_SIZE: (xx, xx)" << std::endl;
return false;
}
// 2. get mean // 2. get mean
for (const auto& item : config["DEPLOY"]["MEAN"]) { if (config["DEPLOY"]["MEAN"].IsDefined()) {
_mean.push_back(item.as<float>()); for (const auto& item : config["DEPLOY"]["MEAN"]) {
_mean.push_back(item.as<float>());
}
} else {
std::cerr << "Please set MEAN: [xx, xx, xx]" << std::endl;
return false;
} }
// 3. get std // 3. get std
for (const auto& item : config["DEPLOY"]["STD"]) { if(config["DEPLOY"]["STD"].IsDefined()) {
_std.push_back(item.as<float>()); for (const auto& item : config["DEPLOY"]["STD"]) {
_std.push_back(item.as<float>());
}
} else {
std::cerr << "Please set STD: [xx, xx, xx]" << std::endl;
return false;
} }
// 4. get image type // 4. get image type
_img_type = config["DEPLOY"]["IMAGE_TYPE"].as<std::string>(); if (config["DEPLOY"]["IMAGE_TYPE"].IsDefined()) {
_img_type = config["DEPLOY"]["IMAGE_TYPE"].as<std::string>();
} else {
std::cerr << "Please set IMAGE_TYPE: \"rgb\" or \"rgba\"" << std::endl;
return false;
}
// 5. get class number // 5. get class number
_class_num = config["DEPLOY"]["NUM_CLASSES"].as<int>(); if (config["DEPLOY"]["NUM_CLASSES"].IsDefined()) {
_class_num = config["DEPLOY"]["NUM_CLASSES"].as<int>();
} else {
std::cerr << "Please set NUM_CLASSES: x" << std::endl;
return false;
}
// 7. set model path // 7. set model path
_model_path = config["DEPLOY"]["MODEL_PATH"].as<std::string>(); if (config["DEPLOY"]["MODEL_PATH"].IsDefined()) {
_model_path = config["DEPLOY"]["MODEL_PATH"].as<std::string>();
} else {
std::cerr << "Please set MODEL_PATH: \"/path/to/model_dir\"" << std::endl;
return false;
}
// 8. get model file_name // 8. get model file_name
_model_file_name = config["DEPLOY"]["MODEL_FILENAME"].as<std::string>(); if (config["DEPLOY"]["MODEL_FILENAME"].IsDefined()) {
_model_file_name = config["DEPLOY"]["MODEL_FILENAME"].as<std::string>();
} else {
_model_file_name = "__model__";
}
// 9. get model param file name // 9. get model param file name
_param_file_name = if (config["DEPLOY"]["PARAMS_FILENAME"].IsDefined()) {
config["DEPLOY"]["PARAMS_FILENAME"].as<std::string>(); _param_file_name
= config["DEPLOY"]["PARAMS_FILENAME"].as<std::string>();
} else {
_param_file_name = "__params__";
}
// 10. get pre_processor // 10. get pre_processor
_pre_processor = config["DEPLOY"]["PRE_PROCESSOR"].as<std::string>(); if (config["DEPLOY"]["PRE_PROCESSOR"].IsDefined()) {
_pre_processor = config["DEPLOY"]["PRE_PROCESSOR"].as<std::string>();
} else {
std::cerr << "Please set PRE_PROCESSOR: \"DetectionPreProcessor\"" << std::endl;
return false;
}
// 11. use_gpu // 11. use_gpu
_use_gpu = config["DEPLOY"]["USE_GPU"].as<int>(); if (config["DEPLOY"]["USE_GPU"].IsDefined()) {
_use_gpu = config["DEPLOY"]["USE_GPU"].as<int>();
} else {
_use_gpu = 0;
}
// 12. predictor_mode // 12. predictor_mode
_predictor_mode = config["DEPLOY"]["PREDICTOR_MODE"].as<std::string>(); if (config["DEPLOY"]["PREDICTOR_MODE"].IsDefined()) {
_predictor_mode = config["DEPLOY"]["PREDICTOR_MODE"].as<std::string>();
} else {
std::cerr << "Please set PREDICTOR_MODE: \"NATIVE\" or \"ANALYSIS\"" << std::endl;
return false;
}
// 13. batch_size // 13. batch_size
_batch_size = config["DEPLOY"]["BATCH_SIZE"].as<int>(); if (config["DEPLOY"]["BATCH_SIZE"].IsDefined()) {
_batch_size = config["DEPLOY"]["BATCH_SIZE"].as<int>();
} else {
_batch_size = 1;
}
// 14. channels // 14. channels
_channels = config["DEPLOY"]["CHANNELS"].as<int>(); if (config["DEPLOY"]["CHANNELS"].IsDefined()) {
_channels = config["DEPLOY"]["CHANNELS"].as<int>();
} else {
std::cerr << "Please set CHANNELS: x" << std::endl;
return false;
}
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册