未验证 提交 7a4275b6 编写于 作者: W wangguanzhong 提交者: GitHub

[Dygraph]add export_model and deploy (#1762)

* add export_model and deploy

* fix travis-ci

* update CMakeList & linux doc

* update by comments
上级 dcf97ccd
......@@ -20,7 +20,7 @@ addons:
before_install:
- sudo pip install -U virtualenv pre-commit pip
- docker pull paddlepaddle/paddle:latest
- git pull https://github.com/PaddlePaddle/PaddleDetection dygraph -r
- git pull https://github.com/PaddlePaddle/PaddleDetection dygraph
script:
- exit_code=0
......
......@@ -42,7 +42,7 @@ TestReader:
fields: ['image', 'im_shape', 'scale_factor', 'im_id']
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [608, 608], interp: 2}
- ResizeOp: {target_size: [608, 608], keep_ratio: False, interp: 2}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
batch_size: 1
# PaddleDetection 预测部署
`PaddleDetection`目前支持使用`Python``C++`部署在`Windows``Linux` 上运行。
`PaddleDetection`目前支持:
- 使用`Python``C++`部署在`Windows``Linux` 上运行
- [在线服务化部署](./serving/README.md)
- [移动端部署](https://github.com/PaddlePaddle/Paddle-Lite-Demo)
## 模型导出
训练得到一个满足要求的模型后,如果想要将该模型接入到C++服务器端预测库或移动端预测库,需要通过`tools/export_model.py`导出该模型。
- [导出教程](../docs/advanced_tutorials/deploy/EXPORT_MODEL.md)
- [导出教程](https://github.com/PaddlePaddle/PaddleDetection/blob/master/docs/advanced_tutorials/deploy/EXPORT_MODEL.md)
模型导出后, 目录结构如下(以`yolov3_darknet`为例):
```
......@@ -18,6 +21,8 @@ yolov3_darknet # 模型目录
预测时,该目录所在的路径会作为程序的输入参数。
## 预测部署
- [1. Python预测(支持 Linux 和 Windows)](./python/)
- [2. C++预测(支持 Linux 和 Windows)](./cpp/)
- [3. 移动端部署参考Paddle-Lite文档](https://paddle-lite.readthedocs.io/zh/latest/)
- [1. Python预测(支持 Linux 和 Windows)](https://github.com/PaddlePaddle/PaddleDetection/blob/master/deploy/python)
- [2. C++预测(支持 Linux 和 Windows)](https://github.com/PaddlePaddle/PaddleDetection/blob/master/deploy/cpp)
- [3. 在线服务化部署](./serving/README.md)
- [4. 移动端部署](https://github.com/PaddlePaddle/Paddle-Lite-Demo)
- [5. Jetson设备部署](./cpp/docs/Jetson_build.md)
......@@ -10,7 +10,8 @@ SET(PADDLE_DIR "" CACHE PATH "Location of libraries")
SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
SET(CUDA_LIB "" CACHE PATH "Location of libraries")
SET(CUDNN_LIB "" CACHE PATH "Location of libraries")
SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
SET(TENSORRT_INC_DIR "" CACHE PATH "Compile demo with TensorRT")
SET(TENSORRT_LIB_DIR "" CACHE PATH "Compile demo with TensorRT")
include(cmake/yaml-cpp.cmake)
......@@ -112,8 +113,8 @@ endif()
if (NOT WIN32)
if (WITH_TENSORRT AND WITH_GPU)
include_directories("${TENSORRT_DIR}/include")
link_directories("${TENSORRT_DIR}/lib")
include_directories("${TENSORRT_INC_DIR}/")
link_directories("${TENSORRT_LIB_DIR}/")
endif()
endif(NOT WIN32)
......@@ -195,8 +196,8 @@ endif(NOT WIN32)
if(WITH_GPU)
if(NOT WIN32)
if (WITH_TENSORRT)
set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX})
......
......@@ -52,7 +52,7 @@ deploy/cpp
## 3.编译部署
### 3.1 导出模型
请确认您已经基于`PaddleDetection`[export_model.py](../../tools/export_model.py)导出您的模型,并妥善保存到合适的位置。导出模型细节请参考 [导出模型教程](../../docs/advanced_tutorials/deploy/EXPORT_MODEL.md)
请确认您已经基于`PaddleDetection`[export_model.py](https://github.com/PaddlePaddle/PaddleDetection/blob/master/tools/export_model.py)导出您的模型,并妥善保存到合适的位置。导出模型细节请参考 [导出模型教程](https://github.com/PaddlePaddle/PaddleDetection/blob/master/docs/advanced_tutorials/deploy/EXPORT_MODEL.md)
模型导出后, 目录结构如下(以`yolov3_darknet`为例):
```
......@@ -67,5 +67,5 @@ yolov3_darknet # 模型目录
### 3.2 编译
仅支持在`Windows``Linux`平台编译和使用
- [Linux 编译指南](./docs/linux_build.md)
- [Windows编译指南(使用Visual Studio 2019)](./docs/windows_vs2019_build.md)
- [Linux 编译指南](https://github.com/PaddlePaddle/PaddleDetection/blob/master/deploy/cpp/docs/linux_build.md)
- [Windows编译指南(使用Visual Studio 2019)](https://github.com/PaddlePaddle/PaddleDetection/blob/master/deploy/cpp/docs/windows_vs2019_build.md)
# Linux平台编译指南
## 说明
本文档在 `Linux`平台使用`GCC 4.8.5``GCC 4.9.4`测试过,如果需要使用更高G++版本编译使用,则需要重新编译Paddle预测库,请参考: [从源码编译Paddle预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html)
本文档在 `Linux`平台使用`GCC 4.8.5``GCC 4.9.4`测试过,如果需要使用更高G++版本编译使用,则需要重新编译Paddle预测库,请参考: [从源码编译Paddle预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html)本文档使用的预置的opencv库是在ubuntu 16.04上用gcc4.8编译的,如果需要在ubuntu 16.04以外的系统环境编译,那么需自行编译opencv库。
## 前置条件
* G++ 4.8.2 ~ 4.9.4
......@@ -40,38 +40,43 @@ fluid_inference
编译`cmake`的命令在`scripts/build.sh`中,请根据实际情况修改主要参数,其主要内容说明如下:
```
# 是否使用GPU(即是否使用 CUDA)
WITH_GPU=OFF
# 使用MKL or openblas
WITH_MKL=ON
# 是否集成 TensorRT(仅WITH_GPU=ON 有效)
WITH_TENSORRT=OFF
# TensorRT 的include路径
TENSORRT_LIB_DIR=/path/to/TensorRT/include
# TensorRT 的lib路径
TENSORRT_DIR=/path/to/TensorRT/
TENSORRT_LIB_DIR=/path/to/TensorRT/lib
# Paddle 预测库路径
PADDLE_DIR=/path/to/fluid_inference/
PADDLE_DIR=/path/to/fluid_inference
# Paddle 的预测库是否使用静态库来编译
# 使用TensorRT时,Paddle的预测库通常为动态库
WITH_STATIC_LIB=OFF
# CUDA 的 lib 路径
CUDA_LIB=/path/to/cuda/lib/
CUDA_LIB=/path/to/cuda/lib
# CUDNN 的 lib 路径
CUDNN_LIB=/path/to/cudnn/lib/
CUDNN_LIB=/path/to/cudnn/lib
# OPENCV 路径, 如果使用自带预编译版本可不修改
sh $(pwd)/scripts/bootstrap.sh # 下载预编译版本的opencv
OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/
# 请检查以上各个路径是否正确
# 以下无需改动
rm -rf build
mkdir -p build
cd build
cmake .. \
-DWITH_GPU=${WITH_GPU} \
-DWITH_MKL=${WITH_MKL} \
-DWITH_TENSORRT=${WITH_TENSORRT} \
-DTENSORRT_DIR=${TENSORRT_DIR} \
-DTENSORRT_LIB_DIR=${TENSORRT_LIB_DIR} \
-DTENSORRT_INC_DIR=${TENSORRT_INC_DIR} \
-DPADDLE_DIR=${PADDLE_DIR} \
-DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
-DCUDA_LIB=${CUDA_LIB} \
......@@ -86,18 +91,23 @@ make
sh ./scripts/build.sh
```
**注意**: OPENCV依赖OPENBLAS,Ubuntu用户需确认系统是否已存在`libopenblas.so`。如未安装,可执行apt-get install libopenblas-dev进行安装。
### Step5: 预测及可视化
编译成功后,预测入口程序为`build/main`其主要命令参数说明如下:
| 参数 | 说明 |
| ---- | ---- |
| model_dir | 导出的预测模型所在路径 |
| image_path | 要预测的图片文件路径 |
| video_path | 要预测的视频文件路径 |
| use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0)|
| --run_mode |使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16)|
| --model_dir | 导出的预测模型所在路径 |
| --image_path | 要预测的图片文件路径 |
| --video_path | 要预测的视频文件路径 |
| --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测)|
| --use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0)|
| --gpu_id | 指定进行推理的GPU device id(默认值为0)|
| --run_mode | 使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16)|
| --run_benchmark | 是否重复预测来进行benchmark测速 |
| --output_dir | 输出图片所在的文件夹, 默认为output |
**注意**如果同时设置了`video_path``image_path`,程序仅预测`video_path`
**注意**: 如果同时设置了`video_path``image_path`,程序仅预测`video_path`
`样例一`
......@@ -106,12 +116,12 @@ make
./build/main --model_dir=/root/projects/models/yolov3_darknet --image_path=/root/projects/images/test.jpeg
```
图片文件`可视化预测结果`会保存在当前目录下`output.jpeg`文件中。
图片文件`可视化预测结果`会保存在当前目录下`output.jpg`文件中。
`样例二`:
```shell
#使用 `GPU`预测视频`/root/projects/videos/test.avi`
./build/main --model_dir=/root/projects/models/yolov3_darknet --video_path=/root/projects/images/test.avi --use_gpu=1
#使用 `GPU`预测视频`/root/projects/videos/test.mp4`
./build/main --model_dir=/root/projects/models/yolov3_darknet --video_path=/root/projects/images/test.mp4 --use_gpu=1
```
视频文件`可视化预测结果`会保存在当前目录下`output.avi`文件中。
视频文件目前支持`.mp4`格式的预测,`可视化预测结果`会保存在当前目录下`output.mp4`文件中。
......@@ -4,9 +4,9 @@ Windows 平台下,我们使用`Visual Studio 2019 Community` 进行了测试
## 前置条件
* Visual Studio 2019
* Visual Studio 2019 (根据Paddle预测库所使用的VS版本选择,请参考 [Visual Studio 不同版本二进制兼容性](https://docs.microsoft.com/zh-cn/cpp/porting/binary-compat-2015-2017?view=vs-2019) )
* CUDA 9.0 / CUDA 10.0,cudnn 7+ (仅在使用GPU版本的预测库时需要)
* CMake 3.0+
* CMake 3.0+ [CMake下载](https://cmake.org/download/)
请确保系统已经安装好上述基本软件,我们使用的是`VS2019`的社区版。
......@@ -40,12 +40,14 @@ fluid_inference
1. 在OpenCV官网下载适用于Windows平台的3.4.6版本, [下载地址](https://sourceforge.net/projects/opencvlibrary/files/3.4.6/opencv-3.4.6-vc14_vc15.exe/download)
2. 运行下载的可执行文件,将OpenCV解压至指定目录,如`D:\projects\opencv`
3. 配置环境变量,如下流程所示
3. 配置环境变量,如下流程所示(如果使用全局绝对路径,可以不用设置环境变量)
- 我的电脑->属性->高级系统设置->环境变量
- 在系统变量中找到Path(如没有,自行创建),并双击编辑
- 新建,将opencv路径填入并保存,如`D:\projects\opencv\build\x64\vc14\bin`
### Step4: 使用Visual Studio 2019直接编译CMake
### Step4: 编译
#### 通过图形化操作编译CMake
1. 打开Visual Studio 2019 Community,点击`继续但无需代码`
![step2](https://paddleseg.bj.bcebos.com/inference/vs2019_step1.png)
......@@ -60,14 +62,14 @@ fluid_inference
![step3](https://paddleseg.bj.bcebos.com/inference/vs2019_step4.png)
4. 点击`浏览`,分别设置编译选项指定`CUDA``OpenCV``Paddle预测库`的路径
4. 点击`浏览`,分别设置编译选项指定`CUDA``CUDNN_LIB``OpenCV``Paddle预测库`的路径
三个编译参数的含义说明如下(带*表示仅在使用**GPU版本**预测库时指定, 其中CUDA库版本尽量对齐,**使用9.0、10.0版本,不使用9.2、10.1等版本CUDA库**):
| 参数名 | 含义 |
| ---- | ---- |
| *CUDA_LIB | CUDA的库路径 |
| CUDNN_LIB | CUDNN的库路径 |
| *CUDNN_LIB | CUDNN的库路径 |
| OPENCV_DIR | OpenCV的安装路径, |
| PADDLE_DIR | Paddle预测库的路径 |
......@@ -81,6 +83,26 @@ fluid_inference
![step6](https://paddleseg.bj.bcebos.com/inference/vs2019_step6.png)
#### 通过命令行操作编译CMake
1. 进入到`cpp`文件夹
```
cd D:\projects\PaddleDetection\deploy\cpp
```
2. 使用CMake生成项目文件
```
cmake . -G "Visual Studio 16 2019" -A x64 -T host=x64 -DWITH_GPU=ON -DWITH_MKL=ON -DCMAKE_BUILD_TYPE=Release -DCUDA_LIB=path_to_cuda_lib -DCUDNN_LIB=path_to_cudnn_lib -DPADDLE_DIR=path_to_paddle_lib -DOPENCV_DIR=path_to_opencv
```
例如:
```
cmake . -G "Visual Studio 16 2019" -A x64 -T host=x64 -DWITH_GPU=ON -DWITH_MKL=ON -DCMAKE_BUILD_TYPE=Release -DCUDA_LIB=D:\projects\packages\cuda10_0\lib\x64 -DCUDNN_LIB=D:\projects\packages\cuda10_0\lib\x64 -DPADDLE_DIR=D:\projects\packages\fluid_inference -DOPENCV_DIR=D:\projects\packages\opencv3_4_6
```
3. 编译
`Visual Studio 16 2019`打开`cpp`文件夹下的`PaddleObjectDetector.sln`,点击`生成`->`全部生成`
### Step5: 预测及可视化
上述`Visual Studio 2019`编译产出的可执行文件在`out\build\x64-Release`目录下,打开`cmd`,并切换到该目录:
......@@ -92,12 +114,19 @@ cd D:\projects\PaddleDetection\deploy\cpp\out\build\x64-Release
| 参数 | 说明 |
| ---- | ---- |
| model_dir | 导出的预测模型所在路径 |
| image_path | 要预测的图片文件路径 |
| video_path | 要预测的视频文件路径 |
| use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0)|
| --model_dir | 导出的预测模型所在路径 |
| --image_path | 要预测的图片文件路径 |
| --video_path | 要预测的视频文件路径 |
| --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测)|
| --use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0)|
| --gpu_id | 指定进行推理的GPU device id(默认值为0)|
| --run_mode | 使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16)|
| --run_benchmark | 是否重复预测来进行benchmark测速 |
| --output_dir | 输出图片所在的文件夹, 默认为output |
**注意**:如果同时设置了`video_path``image_path`,程序仅预测`video_path`
**注意**
(1)如果同时设置了`video_path``image_path`,程序仅预测`video_path`
(2)如果提示找不到`opencv_world346.dll`,把`D:\projects\packages\opencv3_4_6\build\x64\vc14\bin`文件夹下的`opencv_world346.dll`拷贝到`main.exe`文件夹下即可。
`样例一`
......@@ -106,13 +135,31 @@ cd D:\projects\PaddleDetection\deploy\cpp\out\build\x64-Release
.\main --model_dir=D:\\models\\yolov3_darknet --image_path=D:\\images\\test.jpeg
```
图片文件`可视化预测结果`会保存在当前目录下`output.jpeg`文件中。
图片文件`可视化预测结果`会保存在当前目录下`output.jpg`文件中。
`样例二`:
```shell
#使用`GPU`测试视频 `D:\\videos\\test.avi`
.\main --model_dir=D:\\models\\yolov3_darknet --video_path=D:\\videos\\test.jpeg --use_gpu=1
#使用`GPU`测试视频 `D:\\videos\\test.mp4`
.\main --model_dir=D:\\models\\yolov3_darknet --video_path=D:\\videos\\test.mp4 --use_gpu=1
```
视频文件`可视化预测结果`会保存在当前目录下`output.avi`文件中。
视频文件目前支持`.mp4`格式的预测,`可视化预测结果`会保存在当前目录下`output.mp4`文件中。
## 性能测试
测试环境为:系统: Windows 10专业版系统,CPU: I9-9820X, GPU: GTX 2080 Ti,Paddle预测库: 1.8.4,CUDA: 10.0, CUDNN: 7.4.
去掉前100轮warmup时间,测试100轮的平均时间,单位ms/image,只计算模型运行时间,不包括数据的处理和拷贝。
|模型 | AnalysisPredictor(ms) | 输入|
|---|----|---|
| YOLOv3-MobileNetv1 | 41.51 | 608*608
| faster_rcnn_r50_1x | 194.47 | 1333*1333
| faster_rcnn_r50_vd_fpn_2x | 43.35 | 1344*1344
| mask_rcnn_r50_fpn_1x | 96.96 | 1344*1344
| mask_rcnn_r50_vd_fpn_2x | 97.66 | 1344*1344
| ppyolo_r18vd | 5.54 | 320*320
| ppyolo_2x | 56.93 | 608*608
| ttfnet_darknet | 36.17 | 512*512
......@@ -98,6 +98,13 @@ class ConfigPaser {
return false;
}
if (config["image_shape"].IsDefined()) {
image_shape_ = config["image_shape"].as<std::vector<int>>();
} else {
std::cerr << "Please set image_shape." << std::endl;
return false;
}
return true;
}
std::string mode_;
......@@ -107,6 +114,7 @@ class ConfigPaser {
bool with_background_;
YAML::Node preprocess_info_;
std::vector<std::string> label_list_;
std::vector<int> image_shape_;
};
} // namespace PaddleDetection
......
......@@ -18,6 +18,7 @@
#include <vector>
#include <memory>
#include <utility>
#include <ctime>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
......@@ -28,6 +29,7 @@
#include "include/preprocess_op.h"
#include "include/config_parser.h"
using namespace paddle_infer;
namespace PaddleDetection {
// Object Detection Result
......@@ -54,12 +56,15 @@ cv::Mat VisualizeResult(const cv::Mat& img,
class ObjectDetector {
public:
explicit ObjectDetector(const std::string& model_dir, bool use_gpu = false,
const std::string& run_mode = "fluid") {
explicit ObjectDetector(const std::string& model_dir,
bool use_gpu=false,
const std::string& run_mode="fluid",
const int gpu_id=0) {
config_.load_config(model_dir);
threshold_ = config_.draw_threshold_;
preprocessor_.Init(config_.preprocess_info_, config_.arch_);
LoadModel(model_dir, use_gpu, config_.min_subgraph_size_, 1, run_mode);
image_shape_ = config_.image_shape_;
preprocessor_.Init(config_.preprocess_info_, image_shape_);
LoadModel(model_dir, use_gpu, config_.min_subgraph_size_, 1, run_mode, gpu_id);
}
// Load Paddle inference model
......@@ -68,12 +73,16 @@ class ObjectDetector {
bool use_gpu,
const int min_subgraph_size,
const int batch_size = 1,
const std::string& run_mode = "fluid");
const std::string& run_mode = "fluid",
const int gpu_id=0);
// Run predictor
void Predict(
const cv::Mat& img,
std::vector<ObjectResult>* result);
void Predict(const cv::Mat& im,
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
const bool run_benchmark = false,
std::vector<ObjectResult>* result = nullptr);
// Get Model Label list
const std::vector<std::string>& GetLabelList() const {
......@@ -88,12 +97,13 @@ class ObjectDetector {
const cv::Mat& raw_mat,
std::vector<ObjectResult>* result);
std::unique_ptr<paddle::PaddlePredictor> predictor_;
std::shared_ptr<Predictor> predictor_;
Preprocessor preprocessor_;
ImageBlob inputs_;
std::vector<float> output_data_;
float threshold_;
ConfigPaser config_;
std::vector<int> image_shape_;
};
} // namespace PaddleDetection
......@@ -14,6 +14,7 @@
#pragma once
#include <glog/logging.h>
#include <yaml-cpp/yaml.h>
#include <vector>
......@@ -31,29 +32,36 @@ namespace PaddleDetection {
// Object for storing all preprocessed data
class ImageBlob {
public:
// Original image width and height
std::vector<int> ori_im_size_;
// image width and height
std::vector<float> im_shape_;
// Buffer for image data after preprocessing
std::vector<float> im_data_;
// Original image width, height, shrink in float format
std::vector<float> ori_im_size_f_;
// input image width, height
std::vector<int> input_shape_;
// Evaluation image width and height
std::vector<float> eval_im_size_f_;
//std::vector<float> eval_im_size_f_;
// Scale factor for image size to origin image size
std::vector<float> scale_factor_;
};
// Abstraction of preprocessing opration class
class PreprocessOp {
public:
virtual void Init(const YAML::Node& item, const std::string& arch) = 0;
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) = 0;
virtual void Run(cv::Mat* im, ImageBlob* data) = 0;
};
class InitInfo : public PreprocessOp{
public:
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) {}
virtual void Run(cv::Mat* im, ImageBlob* data);
};
class Normalize : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item, const std::string& arch) {
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) {
mean_ = item["mean"].as<std::vector<float>>();
scale_ = item["std"].as<std::vector<float>>();
is_channel_first_ = item["is_channel_first"].as<bool>();
is_scale_ = item["is_scale"].as<bool>();
}
......@@ -61,36 +69,28 @@ class Normalize : public PreprocessOp {
private:
// CHW or HWC
bool is_channel_first_;
bool is_scale_;
std::vector<float> mean_;
std::vector<float> scale_;
bool is_scale_;
};
class Permute : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item, const std::string& arch) {
to_bgr_ = item["to_bgr"].as<bool>();
is_channel_first_ = item["channel_first"].as<bool>();
}
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) {}
virtual void Run(cv::Mat* im, ImageBlob* data);
private:
// RGB to BGR
bool to_bgr_;
// CHW or HWC
bool is_channel_first_;
};
class Resize : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item, const std::string& arch) {
arch_ = arch;
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) {
interp_ = item["interp"].as<int>();
max_size_ = item["max_size"].as<int>();
target_size_ = item["target_size"].as<int>();
image_shape_ = item["image_shape"].as<std::vector<int>>();
//max_size_ = item["target_size"].as<int>();
keep_ratio_ = item["keep_ratio"].as<bool>();
target_size_ = item["target_size"].as<std::vector<int>>();
if (item["keep_ratio"]) {
input_shape_ = image_shape;
}
}
// Compute best resize scale for x-dimension, y-dimension
......@@ -99,17 +99,16 @@ class Resize : public PreprocessOp {
virtual void Run(cv::Mat* im, ImageBlob* data);
private:
std::string arch_;
int interp_;
int max_size_;
int target_size_;
std::vector<int> image_shape_;
bool keep_ratio_;
std::vector<int> target_size_;
std::vector<int> input_shape_;
};
// Models with FPN need input shape % stride == 0
class PadStride : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item, const std::string& arch) {
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) {
stride_ = item["stride"].as<int>();
}
......@@ -121,23 +120,25 @@ class PadStride : public PreprocessOp {
class Preprocessor {
public:
void Init(const YAML::Node& config_node, const std::string& arch) {
arch_ = arch;
void Init(const YAML::Node& config_node, const std::vector<int> image_shape) {
// initialize image info at first
ops_["InitInfo"] = std::make_shared<InitInfo>();
for (const auto& item : config_node) {
auto op_name = item["type"].as<std::string>();
ops_[op_name] = CreateOp(op_name);
ops_[op_name]->Init(item, arch);
ops_[op_name]->Init(item, image_shape);
}
}
std::shared_ptr<PreprocessOp> CreateOp(const std::string& name) {
if (name == "Resize") {
if (name == "ResizeOp") {
return std::make_shared<Resize>();
} else if (name == "Permute") {
} else if (name == "PermuteOp") {
return std::make_shared<Permute>();
} else if (name == "Normalize") {
} else if (name == "NormalizeImageOp") {
return std::make_shared<Normalize>();
} else if (name == "PadStride") {
} else if (name == "PadBatchOp") {
return std::make_shared<PadStride>();
}
return nullptr;
......@@ -149,8 +150,8 @@ class Preprocessor {
static const std::vector<std::string> RUN_ORDER;
private:
std::string arch_;
std::unordered_map<std::string, std::shared_ptr<PreprocessOp>> ops_;
};
} // namespace PaddleDetection
# download pre-compiled opencv lib
OPENCV_URL=https://paddleseg.bj.bcebos.com/deploy/docker/opencv3gcc4.8.tar.bz2
if [ ! -d "./deps/opencv3gcc4.8" ]; then
mkdir -p deps
cd deps
wget -c ${OPENCV_URL}
tar xvfj opencv3gcc4.8.tar.bz2
rm -rf opencv3gcc4.8.tar.bz2
cd ..
fi
# 是否使用GPU(即是否使用 CUDA)
WITH_GPU=OFF
# 使用MKL or openblas
# 是否使用MKL or openblas,TX2需要设置为OFF
WITH_MKL=ON
# 是否集成 TensorRT(仅WITH_GPU=ON 有效)
WITH_TENSORRT=OFF
# TensorRT 的路径
TENSORRT_DIR=/path/to/TensorRT/
# TensorRT 的include路径
TENSORRT_INC_DIR=/path/to/tensorrt/lib
# TensorRT 的lib路径
TENSORRT_LIB_DIR=/path/to/tensorrt/include
# Paddle 预测库路径
PADDLE_DIR=/path/to/fluid_inference/
# Paddle 的预测库是否使用静态库来编译
# 使用TensorRT时,Paddle的预测库通常为动态库
WITH_STATIC_LIB=OFF
# CUDA 的 lib 路径
CUDA_LIB=/path/to/cuda/lib/
CUDA_LIB=/path/to/cuda/lib
# CUDNN 的 lib 路径
CUDNN_LIB=/path/to/cudnn/lib/
CUDNN_LIB=/path/to/cudnn/lib
# OPENCV 路径, 如果使用自带预编译版本可不修改
sh $(pwd)/scripts/bootstrap.sh # 下载预编译版本的opencv
OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/
MACHINE_TYPE=`uname -m`
echo "MACHINE_TYPE: "${MACHINE_TYPE}
if [ "$MACHINE_TYPE" = "x86_64" ]
then
echo "set OPENCV_DIR for x86_64"
# linux系统通过以下命令下载预编译的opencv
mkdir -p $(pwd)/deps && cd $(pwd)/deps
wget -c https://bj.bcebos.com/paddleseg/deploy/opencv3.4.6gcc4.8ffmpeg.tar.gz2
tar xvfj opencv3.4.6gcc4.8ffmpeg.tar.gz2 && cd ..
# set OPENCV_DIR
OPENCV_DIR=$(pwd)/deps/opencv3.4.6gcc4.8ffmpeg/
elif [ "$MACHINE_TYPE" = "aarch64" ]
then
echo "set OPENCV_DIR for aarch64"
# TX2平台通过以下命令下载预编译的opencv
mkdir -p $(pwd)/deps && cd $(pwd)/deps
wget -c https://paddlemodels.bj.bcebos.com/TX2_JetPack4.3_opencv_3.4.10_gcc7.5.0.zip
unzip TX2_JetPack4.3_opencv_3.4.10_gcc7.5.0.zip && cd ..
# set OPENCV_DIR
OPENCV_DIR=$(pwd)/deps/TX2_JetPack4.3_opencv_3.4.10_gcc7.5.0/
else
echo "Please set OPENCV_DIR manually"
fi
echo "OPENCV_DIR: "$OPENCV_DIR
# 以下无需改动
rm -rf build
......@@ -28,10 +67,13 @@ cmake .. \
-DWITH_GPU=${WITH_GPU} \
-DWITH_MKL=${WITH_MKL} \
-DWITH_TENSORRT=${WITH_TENSORRT} \
-DTENSORRT_DIR=${TENSORRT_DIR} \
-DTENSORRT_LIB_DIR=${TENSORRT_LIB_DIR} \
-DTENSORRT_INC_DIR=${TENSORRT_INC_DIR} \
-DPADDLE_DIR=${PADDLE_DIR} \
-DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
-DCUDA_LIB=${CUDA_LIB} \
-DCUDNN_LIB=${CUDNN_LIB} \
-DOPENCV_DIR=${OPENCV_DIR}
make
echo "make finished!"
......@@ -17,6 +17,16 @@
#include <iostream>
#include <string>
#include <vector>
#include <sys/types.h>
#include <sys/stat.h>
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#elif LINUX
#include <stdarg.h>
#include <sys/stat.h>
#endif
#include "include/object_detector.h"
......@@ -25,13 +35,64 @@ DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(image_path, "", "Path of input image");
DEFINE_string(video_path, "", "Path of input video");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_string(run_mode, "fluid", "mode of running(fluid/trt_fp32/trt_fp16)");
DEFINE_bool(use_camera, false, "Use camera or not");
DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_int32(camera_id, -1, "Device id of camera to predict");
DEFINE_bool(run_benchmark, false, "Whether to predict a image_file repeatedly for benchmark");
DEFINE_double(threshold, 0.5, "Threshold of score.");
DEFINE_string(output_dir, "output", "Directory of output visualization files.");
static std::string DirName(const std::string &filepath) {
auto pos = filepath.rfind(OS_PATH_SEP);
if (pos == std::string::npos) {
return "";
}
return filepath.substr(0, pos);
}
static bool PathExists(const std::string& path){
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
#else
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0);
#endif // !_WIN32
}
static void MkDir(const std::string& path) {
if (PathExists(path)) return;
int ret = 0;
#ifdef _WIN32
ret = _mkdir(path.c_str());
#else
ret = mkdir(path.c_str(), 0755);
#endif // !_WIN32
if (ret != 0) {
std::string path_error(path);
path_error += " mkdir failed!";
throw std::runtime_error(path_error);
}
}
static void MkDirs(const std::string& path) {
if (path.empty()) return;
if (PathExists(path)) return;
MkDirs(DirName(path));
MkDir(path);
}
void PredictVideo(const std::string& video_path,
PaddleDetection::ObjectDetector* det) {
// Open video
cv::VideoCapture capture;
if (FLAGS_camera_id != -1){
capture.open(FLAGS_camera_id);
}else{
capture.open(video_path.c_str());
}
if (!capture.isOpened()) {
printf("can not open video : %s\n", video_path.c_str());
return;
......@@ -44,9 +105,9 @@ void PredictVideo(const std::string& video_path,
// Create VideoWriter for output
cv::VideoWriter video_out;
std::string video_out_path = "output.avi";
std::string video_out_path = "output.mp4";
video_out.open(video_out_path.c_str(),
CV_FOURCC('M', 'J', 'P', 'G'),
0x00000021,
video_fps,
cv::Size(video_width, video_height),
true);
......@@ -60,28 +121,48 @@ void PredictVideo(const std::string& video_path,
auto colormap = PaddleDetection::GenerateColorMap(labels.size());
// Capture all frames and do inference
cv::Mat frame;
int frame_id = 0;
while (capture.read(frame)) {
if (frame.empty()) {
break;
}
det->Predict(frame, &result);
det->Predict(frame, 0.5, 0, 1, false, &result);
cv::Mat out_im = PaddleDetection::VisualizeResult(
frame, result, labels, colormap);
for (const auto& item : result) {
printf("In frame id %d, we detect: class=%d confidence=%.2f rect=[%d %d %d %d]\n",
frame_id,
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
video_out.write(out_im);
frame_id += 1;
}
capture.release();
video_out.release();
}
void PredictImage(const std::string& image_path,
PaddleDetection::ObjectDetector* det) {
const double threshold,
const bool run_benchmark,
PaddleDetection::ObjectDetector* det,
const std::string& output_dir = "output") {
// Open input image as an opencv cv::Mat object
cv::Mat im = cv::imread(image_path, 1);
// Store all detected result
std::vector<PaddleDetection::ObjectResult> result;
det->Predict(im, &result);
if (run_benchmark)
{
det->Predict(im, threshold, 100, 100, run_benchmark, &result);
}else
{
det->Predict(im, 0.5, 0, 1, run_benchmark, &result);
for (const auto& item : result) {
printf("class=%d confidence=%.2f rect=[%d %d %d %d]\n",
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
......@@ -97,8 +178,14 @@ void PredictImage(const std::string& image_path,
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
cv::imwrite("output.jpeg", vis_img, compression_params);
printf("Visualized output saved as output.jpeg\n");
std::string output_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
output_path += OS_PATH_SEP;
}
output_path += "output.jpg";
cv::imwrite(output_path, vis_img, compression_params);
printf("Visualized output saved as %s\n", output_path.c_str());
}
}
int main(int argc, char** argv) {
......@@ -115,15 +202,18 @@ int main(int argc, char** argv) {
std::cout << "run_mode should be 'fluid', 'trt_fp32' or 'trt_fp16'.";
return -1;
}
// Load model and create a object detector
PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_use_gpu,
FLAGS_run_mode);
FLAGS_run_mode, FLAGS_gpu_id);
// Do inference on input video or image
if (!FLAGS_video_path.empty()) {
if (!FLAGS_video_path.empty() || FLAGS_use_camera) {
PredictVideo(FLAGS_video_path, &det);
} else if (!FLAGS_image_path.empty()) {
PredictImage(FLAGS_image_path, &det);
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
}
PredictImage(FLAGS_image_path, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir);
}
return 0;
}
......@@ -11,8 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
// for setprecision
#include <iomanip>
#include "include/object_detector.h"
# include "include/object_detector.h"
using namespace paddle_infer;
namespace PaddleDetection {
......@@ -21,22 +26,24 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
bool use_gpu,
const int min_subgraph_size,
const int batch_size,
const std::string& run_mode) {
paddle::AnalysisConfig config;
std::string prog_file = model_dir + OS_PATH_SEP + "__model__";
std::string params_file = model_dir + OS_PATH_SEP + "__params__";
const std::string& run_mode,
const int gpu_id) {
paddle_infer::Config config;
std::string prog_file = model_dir + OS_PATH_SEP + "model.pdmodel";
std::string params_file = model_dir + OS_PATH_SEP + "model.pdiparams";
config.SetModel(prog_file, params_file);
if (use_gpu) {
config.EnableUseGpu(100, 0);
config.EnableUseGpu(200, gpu_id);
config.SwitchIrOptim(true);
if (run_mode != "fluid") {
auto precision = paddle::AnalysisConfig::Precision::kFloat32;
auto precision = paddle_infer::Config::Precision::kFloat32;
if (run_mode == "trt_fp16") {
precision = paddle::AnalysisConfig::Precision::kHalf;
precision = paddle_infer::Config::Precision::kHalf;
} else if (run_mode == "trt_int8") {
printf("TensorRT int8 mode is not supported now, "
"please use 'trt_fp32' or 'trt_fp16' instead");
} else {
if (run_mode != "trt_32") {
if (run_mode != "trt_fp32") {
printf("run_mode should be 'fluid', 'trt_fp32' or 'trt_fp16'");
}
}
......@@ -52,10 +59,10 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
config.DisableGpu();
}
config.SwitchUseFeedFetchOps(false);
config.SwitchSpecifyInputNames(true);
config.DisableGlogInfo();
// Memory optimization
config.EnableMemoryOptim();
predictor_ = std::move(CreatePaddlePredictor(config));
predictor_ = std::move(CreatePredictor(config));
}
// Visualiztion MaskDetector results
......@@ -70,13 +77,15 @@ cv::Mat VisualizeResult(const cv::Mat& img,
cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[2], w, h);
// Configure color and text size
std::string text = lable_list[results[i].class_id];
std::ostringstream oss;
oss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
oss << lable_list[results[i].class_id] << " ";
oss << results[i].confidence;
std::string text = oss.str();
int c1 = colormap[3 * results[i].class_id + 0];
int c2 = colormap[3 * results[i].class_id + 1];
int c3 = colormap[3 * results[i].class_id + 2];
cv::Scalar roi_color = cv::Scalar(c1, c2, c3);
text += " ";
text += std::to_string(static_cast<int>(results[i].confidence * 100)) + "%";
int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL;
double font_scale = 0.5f;
float thickness = 0.5;
......@@ -139,7 +148,7 @@ void ObjectDetector::Postprocess(
int ymax = (output_data_[5 + j * 6] * rh);
int wd = xmax - xmin;
int hd = ymax - ymin;
if (score > threshold_) {
if (score > threshold_ && class_id > -1) {
ObjectResult result_item;
result_item.rect = {xmin, xmax, ymin, ymax};
result_item.class_id = class_id;
......@@ -150,44 +159,78 @@ void ObjectDetector::Postprocess(
}
void ObjectDetector::Predict(const cv::Mat& im,
const double threshold,
const int warmup,
const int repeats,
const bool run_benchmark,
std::vector<ObjectResult>* result) {
// Preprocess image
Preprocess(im);
// Prepare input tensor
auto input_names = predictor_->GetInputNames();
for (const auto& tensor_name : input_names) {
auto in_tensor = predictor_->GetInputTensor(tensor_name);
auto in_tensor = predictor_->GetInputHandle(tensor_name);
if (tensor_name == "image") {
int rh = inputs_.eval_im_size_f_[0];
int rw = inputs_.eval_im_size_f_[1];
int rh = inputs_.input_shape_[0];
int rw = inputs_.input_shape_[1];
in_tensor->Reshape({1, 3, rh, rw});
in_tensor->copy_from_cpu(inputs_.im_data_.data());
} else if (tensor_name == "im_size") {
in_tensor->Reshape({1, 2});
in_tensor->copy_from_cpu(inputs_.ori_im_size_.data());
} else if (tensor_name == "im_info") {
in_tensor->Reshape({1, 3});
in_tensor->copy_from_cpu(inputs_.eval_im_size_f_.data());
in_tensor->CopyFromCpu(inputs_.im_data_.data());
} else if (tensor_name == "im_shape") {
in_tensor->Reshape({1, 3});
in_tensor->copy_from_cpu(inputs_.ori_im_size_f_.data());
in_tensor->Reshape({1, 2});
in_tensor->CopyFromCpu(inputs_.im_shape_.data());
} else if (tensor_name == "scale_factor") {
in_tensor->Reshape({1, 2});
in_tensor->CopyFromCpu(inputs_.scale_factor_.data());
}
}
// Run predictor
predictor_->ZeroCopyRun();
for (int i = 0; i < warmup; i++)
{
predictor_->Run();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
auto out_tensor = predictor_->GetOutputTensor(output_names[0]);
auto out_tensor = predictor_->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = out_tensor->shape();
// Calculate output length
int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) {
output_size *= output_shape[j];
}
if (output_size < 6) {
std::cerr << "[WARNING] No object detected." << std::endl;
}
output_data_.resize(output_size);
out_tensor->copy_to_cpu(output_data_.data());
out_tensor->CopyToCpu(output_data_.data());
}
std::clock_t start = clock();
for (int i = 0; i < repeats; i++)
{
predictor_->Run();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
auto out_tensor = predictor_->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = out_tensor->shape();
// Calculate output length
int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) {
output_size *= output_shape[j];
}
if (output_size < 6) {
std::cerr << "[WARNING] No object detected." << std::endl;
}
output_data_.resize(output_size);
out_tensor->CopyToCpu(output_data_.data());
}
std::clock_t end = clock();
float ms = static_cast<float>(end - start) / CLOCKS_PER_SEC / repeats * 1000.;
printf("Inference: %f ms per batch image\n", ms);
// Postprocessing result
if(!run_benchmark) {
Postprocess(im, result);
}
}
std::vector<int> GenerateColorMap(int num_class) {
......
......@@ -19,6 +19,18 @@
namespace PaddleDetection {
void InitInfo::Run(cv::Mat* im, ImageBlob* data) {
data->im_shape_ = {
static_cast<float>(im->rows),
static_cast<float>(im->cols)
};
data->scale_factor_ = {1., 1.};
data->input_shape_ = {
static_cast<int>(im->rows),
static_cast<int>(im->cols)
};
}
void Normalize::Run(cv::Mat* im, ImageBlob* data) {
double e = 1.0;
if (is_scale_) {
......@@ -49,34 +61,34 @@ void Permute::Run(cv::Mat* im, ImageBlob* data) {
}
void Resize::Run(cv::Mat* im, ImageBlob* data) {
data->ori_im_size_ = {
static_cast<int>(im->rows),
static_cast<int>(im->cols)
};
data->ori_im_size_f_ = {
static_cast<float>(im->rows),
static_cast<float>(im->cols),
1.0
};
auto resize_scale = GenerateScale(*im);
cv::resize(
*im, *im, cv::Size(), resize_scale.first, resize_scale.second, interp_);
if (max_size_ != 0 && !image_shape_.empty()) {
data->im_shape_ = {
static_cast<float>(im->rows),
static_cast<float>(im->cols),
};
data->scale_factor_ = {
resize_scale.second,
resize_scale.first,
};
if (keep_ratio_) {
int max_size = input_shape_[1];
// Padding the image with 0 border
cv::copyMakeBorder(
*im,
*im,
0,
max_size_ - im->rows,
max_size - im->rows,
0,
max_size_ - im->cols,
max_size - im->cols,
cv::BORDER_CONSTANT,
cv::Scalar(0));
}
data->eval_im_size_f_ = {
static_cast<float>(im->rows),
static_cast<float>(im->cols),
resize_scale.first
data->input_shape_ = {
static_cast<int>(im->rows),
static_cast<int>(im->cols),
};
}
......@@ -85,23 +97,22 @@ std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
int origin_w = im.cols;
int origin_h = im.rows;
if (max_size_ != 0 && (arch_ == "RCNN" || arch_ == "RetinaNet")) {
if (keep_ratio_) {
int im_size_max = std::max(origin_w, origin_h);
int im_size_min = std::min(origin_w, origin_h);
float scale_ratio =
static_cast<float>(target_size_) / static_cast<float>(im_size_min);
if (max_size_ > 0) {
if (round(scale_ratio * im_size_max) > max_size_) {
scale_ratio =
static_cast<float>(max_size_) / static_cast<float>(im_size_max);
}
}
int target_size_max = *std::max_element(target_size_.begin(), target_size_.end());
int target_size_min = *std::min_element(target_size_.begin(), target_size_.end());
float scale_min =
static_cast<float>(target_size_min) / static_cast<float>(im_size_min);
float scale_max =
static_cast<float>(target_size_max) / static_cast<float>(im_size_max);
float scale_ratio = std::min(scale_min, scale_max);
resize_scale = {scale_ratio, scale_ratio};
} else {
resize_scale.first =
static_cast<float>(target_size_) / static_cast<float>(origin_w);
static_cast<float>(target_size_[1]) / static_cast<float>(origin_w);
resize_scale.second =
static_cast<float>(target_size_) / static_cast<float>(origin_h);
static_cast<float>(target_size_[0]) / static_cast<float>(origin_h);
}
return resize_scale;
}
......@@ -124,14 +135,17 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) {
nw - rw,
cv::BORDER_CONSTANT,
cv::Scalar(0));
(data->eval_im_size_f_)[0] = static_cast<float>(im->rows);
(data->eval_im_size_f_)[1] = static_cast<float>(im->cols);
data->input_shape_ = {
static_cast<int>(im->rows),
static_cast<int>(im->cols),
};
}
// Preprocessor op running order
const std::vector<std::string> Preprocessor::RUN_ORDER = {
"Resize", "Normalize", "PadStride", "Permute"
"InitInfo", "ResizeOp", "NormalizeImageOp", "PadStrideOp", "PermuteOp"
};
void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
......
......@@ -3,7 +3,7 @@
Python预测可以使用`tools/infer.py`,此种方式依赖PaddleDetection源码;也可以使用本篇教程预测方式,先将模型导出,使用一个独立的文件进行预测。
本篇教程使用AnalysisPredictor对[导出模型](../../docs/advanced_tutorials/deploy/EXPORT_MODEL.md)进行高性能预测。
本篇教程使用AnalysisPredictor对[导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/master/docs/advanced_tutorials/deploy/EXPORT_MODEL.md)进行高性能预测。
在PaddlePaddle中预测引擎和训练引擎底层有着不同的优化方法, 下面列出了两种不同的预测方式。Executor同时支持训练和预测,AnalysisPredictor则专门针对推理进行了优化,是基于[C++预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html)的Python接口,该引擎可以对模型进行多项图优化,减少不必要的内存拷贝。如果用户在部署已训练模型的过程中对性能有较高的要求,我们提供了独立于PaddleDetection的预测脚本,方便用户直接集成部署。
......@@ -18,7 +18,7 @@ Python预测可以使用`tools/infer.py`,此种方式依赖PaddleDetection源
## 1. 导出预测模型
PaddleDetection在训练过程包括网络的前向和优化器相关参数,而在部署过程中,我们只需要前向参数,具体参考:[导出模型](../../docs/advanced_tutorials/deploy/EXPORT_MODEL.md)
PaddleDetection在训练过程包括网络的前向和优化器相关参数,而在部署过程中,我们只需要前向参数,具体参考:[导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/master/docs/advanced_tutorials/deploy/EXPORT_MODEL.md)
导出后目录下,包括`__model__``__params__``infer_cfg.yml`三个文件。
......@@ -42,12 +42,14 @@ python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/
| 参数 | 是否必须|含义 |
|-------|-------|----------|
| --model_dir | Yes|上述导出的模型路径 |
| --image_file | Yes |需要预测的图片 |
| --video_file | Yes |需要预测的视频 |
| --image_file | Option |需要预测的图片 |
| --video_file | Option |需要预测的视频 |
| --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测,可设置为:0 - (摄像头数目-1) ),预测过程中在可视化界面按`q`退出输出预测结果到:output/output.mp4|
| --use_gpu |No|是否GPU,默认为False|
| --run_mode |No|使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16)|
| --threshold |No|预测得分的阈值,默认为0.5|
| --output_dir |No|可视化结果保存的根目录,默认为output/|
| --run_benchmark |No|是否运行benchmark,同时需指定--image_file|
说明:
......
......@@ -16,228 +16,143 @@ import os
import argparse
import time
import yaml
import ast
from functools import reduce
from PIL import Image
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid
from preprocess import preprocess, ResizeOp, NormalizeImageOp, PermuteOp, PadStride
from visualize import visualize_box_mask
from paddle.inference import Config
from paddle.inference import create_predictor
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str/np.ndarray): path of image/ np.ndarray read by cv2
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_info['origin_shape'] = im.shape[:2]
im_info['resize_shape'] = im.shape[:2]
else:
im = im_file
im_info['origin_shape'] = im.shape[:2]
im_info['resize_shape'] = im.shape[:2]
return im, im_info
# Global dictionary
SUPPORT_MODELS = {
'YOLO',
'RCNN',
}
class Resize(object):
"""resize image by target_size and max_size
class Detector(object):
"""
Args:
arch (str): model type
target_size (int): the target size of image
max_size (int): the max size of image
use_cv2 (bool): whether us cv2
image_shape (list): input shape of model
interp (int): method of resize
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
use_gpu (bool): whether use gpu
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
threshold (float): threshold to reserve the result for output.
"""
def __init__(self,
arch,
target_size,
max_size,
use_cv2=True,
image_shape=None,
interp=cv2.INTER_LINEAR):
self.target_size = target_size
self.max_size = max_size
self.image_shape = image_shape,
self.arch = arch
self.use_cv2 = use_cv2
self.interp = interp
self.scale_set = {'RCNN', 'RetinaNet'}
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im_channel = im.shape[2]
im_scale_x, im_scale_y = self.generate_scale(im)
if self.use_cv2:
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
else:
resize_w = int(im_scale_x * float(im.shape[1]))
resize_h = int(im_scale_y * float(im.shape[0]))
if self.max_size != 0:
raise TypeError(
'If you set max_size to cap the maximum size of image,'
'please set use_cv2 to True to resize the image.')
im = im.astype('uint8')
im = Image.fromarray(im)
im = im.resize((int(resize_w), int(resize_h)), self.interp)
im = np.array(im)
pred_config,
model_dir,
use_gpu=False,
run_mode='fluid',
threshold=0.5):
self.pred_config = pred_config
self.predictor = load_predictor(
model_dir,
run_mode=run_mode,
min_subgraph_size=self.pred_config.min_subgraph_size,
use_gpu=use_gpu)
# padding im when image_shape fixed by infer_cfg.yml
if self.max_size != 0 and self.image_shape is not None:
padding_im = np.zeros(
(self.max_size, self.max_size, im_channel), dtype=np.float32)
im_h, im_w = im.shape[:2]
padding_im[:im_h, :im_w, :] = im
im = padding_im
def preprocess(self, im):
preprocess_ops = []
for op_info in self.pred_config.preprocess_infos:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
im, im_info = preprocess(im, preprocess_ops,
self.pred_config.input_shape)
inputs = create_inputs(im, im_info)
return inputs
if self.arch in self.scale_set:
im_info['scale'] = im_scale_x
im_info['resize_shape'] = im.shape[:2]
return im, im_info
def postprocess(self, np_boxes, np_masks, inputs, threshold=0.5):
# postprocess output of predictor
results = {}
if self.pred_config.arch in ['SSD', 'Face']:
h, w = inputs['im_shape']
scale_y, scale_x = inputs['scale_factor']
w, h = float(h) / scale_y, float(w) / scale_x
np_boxes[:, 2] *= h
np_boxes[:, 3] *= w
np_boxes[:, 4] *= h
np_boxes[:, 5] *= w
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
for box in np_boxes:
print('class_id:{:d}, confidence:{:.4f},'
'left_top:[{:.2f},{:.2f}],'
' right_bottom:[{:.2f},{:.2f}]'.format(
int(box[0]), box[1], box[2], box[3], box[4], box[5]))
results['boxes'] = np_boxes
if np_masks is not None:
np_masks = np_masks[expect_boxes, :, :, :]
results['masks'] = np_masks
return results
def generate_scale(self, im):
"""
def predict(self,
image,
threshold=0.5,
warmup=0,
repeats=1,
run_benchmark=False):
'''
Args:
im (np.ndarray): image (np.ndarray)
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.max_size != 0 and self.arch in self.scale_set:
im_size_min = np.min(origin_shape[0:2])
im_size_max = np.max(origin_shape[0:2])
im_scale = float(self.target_size) / float(im_size_min)
if np.round(im_scale * im_size_max) > self.max_size:
im_scale = float(self.max_size) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
im_scale_x = float(self.target_size) / float(origin_shape[1])
im_scale_y = float(self.target_size) / float(origin_shape[0])
return im_scale_x, im_scale_y
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
'''
inputs = self.preprocess(image)
np_boxes, np_masks = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
class Normalize(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
"""
for i in range(warmup):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.pred_config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_handle(output_names[1])
np_masks = masks_tensor.copy_to_cpu()
def __init__(self, mean, std, is_scale=True, is_channel_first=False):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.is_channel_first = is_channel_first
t1 = time.time()
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.pred_config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_handle(output_names[1])
np_masks = masks_tensor.copy_to_cpu()
t2 = time.time()
ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms))
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_channel_first:
mean = np.array(self.mean)[:, np.newaxis, np.newaxis]
std = np.array(self.std)[:, np.newaxis, np.newaxis]
# do not perform postprocess in benchmark mode
results = []
if not run_benchmark:
if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
print('[WARNNING] No object detected.')
results = {'boxes': np.array([])}
else:
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, to_bgr=False, channel_first=True):
self.to_bgr = to_bgr
self.channel_first = channel_first
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if self.channel_first:
im = im.transpose((2, 0, 1)).copy()
if self.to_bgr:
im = im[[2, 1, 0], :, :]
return im, im_info
results = self.postprocess(
np_boxes, np_masks, inputs, threshold=threshold)
class PadStride(object):
""" padding image for model with FPN
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
return results
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride == 0:
return im
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
im_info['resize_shape'] = padding_im.shape[1:]
return padding_im, im_info
def create_inputs(im, im_info, model_arch='YOLO'):
def create_inputs(im, im_info):
"""generate input for different model type
Args:
im (np.ndarray): image (np.ndarray)
......@@ -247,30 +162,19 @@ def create_inputs(im, im_info, model_arch='YOLO'):
inputs (dict): input of model
"""
inputs = {}
inputs['image'] = im
origin_shape = list(im_info['origin_shape'])
resize_shape = list(im_info['resize_shape'])
scale = im_info['scale']
if 'YOLO' in model_arch:
im_size = np.array([origin_shape]).astype('int32')
inputs['im_size'] = im_size
elif 'RetinaNet' in model_arch:
im_info = np.array([resize_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info
elif 'RCNN' in model_arch:
im_info = np.array([resize_shape + [scale]]).astype('float32')
im_shape = np.array([origin_shape + [1.]]).astype('float32')
inputs['im_info'] = im_info
inputs['im_shape'] = im_shape
inputs['image'] = np.array((im, )).astype('float32')
inputs['im_shape'] = np.array((im_info['im_shape'], )).astype('float32')
inputs['scale_factor'] = np.array(
(im_info['scale_factor'], )).astype('float32')
return inputs
class Config():
class PredictConfig():
"""set config of preprocess, postprocess and visualize
Args:
model_dir (str): root path of model.yml
"""
support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face']
def __init__(self, model_dir):
# parsing Yaml config for Preprocess
......@@ -280,24 +184,32 @@ class Config():
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.use_python_inference = yml_conf['use_python_inference']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.labels = yml_conf['label_list']
self.mask_resolution = None
if 'mask_resolution' in yml_conf:
self.mask_resolution = yml_conf['mask_resolution']
self.input_shape = yml_conf['image_shape']
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in self.support_models:
for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError(
"Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face".
format(yml_conf['arch']))
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def load_predictor(model_dir,
......@@ -321,16 +233,17 @@ def load_predictor(model_dir,
if run_mode == 'trt_int8':
raise ValueError("TensorRT int8 mode is not supported now, "
"please use trt_fp32 or trt_fp16 instead.")
config = Config(
os.path.join(model_dir, 'model.pdmodel'),
os.path.join(model_dir, 'model.pdiparams'))
precision_map = {
'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
'trt_int8': Config.Precision.Int8,
'trt_fp32': Config.Precision.Float32,
'trt_fp16': Config.Precision.Half
}
config = fluid.core.AnalysisConfig(
os.path.join(model_dir, '__model__'),
os.path.join(model_dir, '__params__'))
if use_gpu:
# initial GPU memory(M), device ID
config.enable_use_gpu(100, 0)
config.enable_use_gpu(200, 0)
# optimize graph and fuse op
config.switch_ir_optim(True)
else:
......@@ -351,32 +264,23 @@ def load_predictor(model_dir,
config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False)
predictor = fluid.core.create_paddle_predictor(config)
predictor = create_predictor(config)
return predictor
def load_executor(model_dir, use_gpu=False):
if use_gpu:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feed_names, fetch_targets = fluid.io.load_inference_model(
dirname=model_dir,
executor=exe,
model_filename='__model__',
params_filename='__params__')
return exe, program, fetch_targets
def visualize(image_file,
results,
labels,
mask_resolution=14,
output_dir='output/'):
output_dir='output/',
threshold=0.5):
# visualize the predict result
im = visualize_box_mask(
image_file, results, labels, mask_resolution=mask_resolution)
image_file,
results,
labels,
mask_resolution=mask_resolution,
threshold=threshold)
img_name = os.path.split(image_file)[-1]
if not os.path.exists(output_dir):
os.makedirs(output_dir)
......@@ -385,142 +289,45 @@ def visualize(image_file,
print("save result to: " + out_path)
class Detector():
"""
Args:
model_dir (str): root path of __model__, __params__ and infer_cfg.yml
use_gpu (bool): whether use gpu
"""
def print_arguments(args):
print('----------- Running Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------')
def __init__(self,
model_dir,
use_gpu=False,
run_mode='fluid',
threshold=0.5):
self.config = Config(model_dir)
if self.config.use_python_inference:
self.executor, self.program, self.fecth_targets = load_executor(
model_dir, use_gpu=use_gpu)
else:
self.predictor = load_predictor(
model_dir,
run_mode=run_mode,
min_subgraph_size=self.config.min_subgraph_size,
use_gpu=use_gpu)
self.preprocess_ops = []
for op_info in self.config.preprocess_infos:
op_type = op_info.pop('type')
if op_type == 'Resize':
op_info['arch'] = self.config.arch
self.preprocess_ops.append(eval(op_type)(**op_info))
def preprocess(self, im):
# process image by preprocess_ops
im_info = {
'scale': 1.,
'origin_shape': None,
'resize_shape': None,
}
im, im_info = decode_image(im, im_info)
for operator in self.preprocess_ops:
im, im_info = operator(im, im_info)
im = np.array((im, )).astype('float32')
inputs = create_inputs(im, im_info, self.config.arch)
return inputs, im_info
def postprocess(self, np_boxes, np_masks, im_info, threshold=0.5):
# postprocess output of predictor
results = {}
if self.config.arch in ['SSD', 'Face']:
w, h = im_info['origin_shape']
np_boxes[:, 2] *= h
np_boxes[:, 3] *= w
np_boxes[:, 4] *= h
np_boxes[:, 5] *= w
expect_boxes = np_boxes[:, 1] > threshold
np_boxes = np_boxes[expect_boxes, :]
for box in np_boxes:
print('class_id:{:d}, confidence:{:.2f},'
'left_top:[{:.2f},{:.2f}],'
' right_bottom:[{:.2f},{:.2f}]'.format(
int(box[0]), box[1], box[2], box[3], box[4], box[5]))
results['boxes'] = np_boxes
if np_masks is not None:
np_masks = np_masks[expect_boxes, :, :, :]
results['masks'] = np_masks
return results
def predict(self, image, threshold=0.5):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
'''
inputs, im_info = self.preprocess(image)
np_boxes, np_masks = None, None
if self.config.use_python_inference:
t1 = time.time()
outs = self.executor.run(self.program,
feed=inputs,
fetch_list=self.fecth_targets,
return_numpy=False)
t2 = time.time()
ms = (t2 - t1) * 1000.0
print("Inference: {} ms per batch image".format(ms))
np_boxes = np.array(outs[0])
if self.config.mask_resolution is not None:
np_masks = np.array(outs[1])
def predict_image(detector):
if FLAGS.run_benchmark:
detector.predict(
FLAGS.image_file,
FLAGS.threshold,
warmup=100,
repeats=100,
run_benchmark=True)
else:
input_names = self.predictor.get_input_names()
for i in range(len(inputs)):
input_tensor = self.predictor.get_input_tensor(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
t1 = time.time()
self.predictor.zero_copy_run()
t2 = time.time()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_tensor(output_names[1])
np_masks = masks_tensor.copy_to_cpu()
ms = (t2 - t1) * 1000.0
print("Inference: {} ms per batch image".format(ms))
results = self.postprocess(
np_boxes, np_masks, im_info, threshold=threshold)
return results
def predict_image():
detector = Detector(
FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
results = detector.predict(FLAGS.image_file, FLAGS.threshold)
visualize(
FLAGS.image_file,
results,
detector.config.labels,
mask_resolution=detector.config.mask_resolution,
output_dir=FLAGS.output_dir)
detector.pred_config.labels,
mask_resolution=detector.pred_config.mask_resolution,
output_dir=FLAGS.output_dir,
threshold=FLAGS.threshold)
def predict_video():
detector = Detector(
FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
def predict_video(detector, camera_id):
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
fps = 30
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# yapf: disable
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_name = os.path.split(FLAGS.video_file)[-1]
# yapf: enable
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
......@@ -536,33 +343,67 @@ def predict_video():
im = visualize_box_mask(
frame,
results,
detector.config.labels,
mask_resolution=detector.config.mask_resolution)
detector.pred_config.labels,
mask_resolution=detector.pred_config.mask_resolution,
threshold=FLAGS.threshold)
im = np.array(im)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
def main():
pred_config = PredictConfig(FLAGS.model_dir)
detector = Detector(
pred_config,
FLAGS.model_dir,
use_gpu=FLAGS.use_gpu,
run_mode=FLAGS.run_mode)
# predict from image
if FLAGS.image_file != '':
predict_image(detector)
# predict from video file or camera video stream
if FLAGS.video_file != '' or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id)
if __name__ == '__main__':
paddle.enable_static()
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--model_dir",
type=str,
default=None,
help=("Directory include:'__model__', '__params__', "
help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."),
required=True)
parser.add_argument(
"--image_file", type=str, default='', help="Path of image file.")
parser.add_argument(
"--video_file", type=str, default='', help="Path of video file.")
parser.add_argument(
"--camera_id",
type=int,
default=-1,
help="device id of camera to predict.")
parser.add_argument(
"--run_mode",
type=str,
default='fluid',
help="mode of running(fluid/trt_fp32/trt_fp16)")
parser.add_argument(
"--use_gpu", default=False, help="Whether to predict with GPU.")
"--use_gpu",
type=ast.literal_eval,
default=False,
help="Whether to predict with GPU.")
parser.add_argument(
"--run_benchmark",
type=ast.literal_eval,
default=False,
help="Whether to predict a image_file repeatedly for benchmark")
parser.add_argument(
"--threshold", type=float, default=0.5, help="Threshold of score.")
parser.add_argument(
......@@ -572,9 +413,8 @@ if __name__ == '__main__':
help="Directory of output visualization files.")
FLAGS = parser.parse_args()
print_arguments(FLAGS)
if FLAGS.image_file != '' and FLAGS.video_file != '':
assert "Cannot predict image and video at the same time"
if FLAGS.image_file != '':
predict_image()
if FLAGS.video_file != '':
predict_video()
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from PIL import Image
import cv2
import numpy as np
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str|np.ndarray): input can be image path or np.ndarray
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
else:
im = im_file
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
return im, im_info
class ResizeOp(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(
self,
target_size,
keep_ratio=True,
interp=cv2.INTER_LINEAR, ):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
# padding im when image_shape fixed by infer_cfg.yml
if self.keep_ratio:
max_size = im_info['input_shape'][1]
padding_im = np.zeros(
(max_size, max_size, im_channel), dtype=np.float32)
im_h, im_w = im.shape[:2]
padding_im[:im_h, :im_w, :] = im
im = padding_im
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class NormalizeImageOp(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
"""
def __init__(self, mean, std, is_scale=True):
self.mean = mean
self.std = std
self.is_scale = is_scale
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
return im, im_info
class PermuteOp(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, ):
super(PermuteOp, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride(object):
""" padding image for model with FPN
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride == 0:
return im
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
def preprocess(im, preprocess_ops, input_shape):
# process image by preprocess_ops
im_info = {
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
'im_shape': None,
'input_shape': input_shape,
}
im, im_info = decode_image(im, im_info)
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
return im, im_info
......@@ -18,18 +18,20 @@ from __future__ import division
import cv2
import numpy as np
from PIL import Image, ImageDraw
from scipy import ndimage
def visualize_box_mask(im, results, labels, mask_resolution=14):
def visualize_box_mask(im, results, labels, mask_resolution=14, threshold=0.5):
"""
Args:
im (str/np.ndarray): path of image/np.ndarray read by cv2
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
labels (list): labels:['class1', ..., 'classn']
mask_resolution (int): shape of a mask is:[mask_resolution, mask_resolution]
threshold (float): Threshold of score.
Returns:
im (PIL.Image.Image): visualized image
"""
......@@ -46,6 +48,14 @@ def visualize_box_mask(im, results, labels, mask_resolution=14):
resolution=mask_resolution)
if 'boxes' in results:
im = draw_box(im, results['boxes'], labels)
if 'segm' in results:
im = draw_segm(
im,
results['segm'],
results['label'],
results['score'],
labels,
threshold=threshold)
return im
......@@ -73,7 +83,7 @@ def get_color_map_list(num_classes):
def expand_boxes(boxes, scale=0.0):
"""
Args:
boxes (np.ndarray): shape:[N,4], N:number of box
boxes (np.ndarray): shape:[N,4], N:number of box,
matix element:[x_min, y_min, x_max, y_max]
scale (float): scale of boxes
Returns:
......@@ -97,7 +107,7 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
np_masks (np.ndarray): shape:[N, class_num, resolution, resolution]
labels (list): labels:['class1', ..., 'classn']
......@@ -152,7 +162,7 @@ def draw_box(im, np_boxes, labels):
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
Returns:
......@@ -180,9 +190,60 @@ def draw_box(im, np_boxes, labels):
fill=color)
# draw label
text = "{} {:.2f}".format(labels[clsid], score)
text = "{} {:.4f}".format(labels[clsid], score)
tw, th = draw.textsize(text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im
def draw_segm(im,
np_segms,
np_label,
np_score,
labels,
threshold=0.5,
alpha=0.7):
"""
Draw segmentation on image
"""
mask_color_id = 0
w_ratio = .4
color_list = get_color_map_list(len(labels))
im = np.array(im).astype('float32')
clsid2color = {}
np_segms = np_segms.astype(np.uint8)
for i in range(np_segms.shape[0]):
mask, score, clsid = np_segms[i], np_score[i], np_label[i] + 1
if score < threshold:
continue
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color_mask = clsid2color[clsid]
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask)
color_mask = np.array(color_mask)
im[idx[0], idx[1], :] *= 1.0 - alpha
im[idx[0], idx[1], :] += alpha * color_mask
sum_x = np.sum(mask, axis=0)
x = np.where(sum_x > 0.5)[0]
sum_y = np.sum(mask, axis=1)
y = np.where(sum_y > 0.5)[0]
x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
cv2.rectangle(im, (x0, y0), (x1, y1),
tuple(color_mask.astype('int32').tolist()), 1)
bbox_text = '%s %.2f' % (labels[clsid], score)
t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3),
tuple(color_mask.astype('int32').tolist()), -1)
cv2.putText(
im,
bbox_text, (x0, y0 - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.3, (0, 0, 0),
1,
lineType=cv2.LINE_AA)
return Image.fromarray(im.astype('uint8'))
......@@ -16,17 +16,21 @@ class BaseArch(nn.Layer):
def __init__(self):
super(BaseArch, self).__init__()
def forward(self, data, input_def, mode):
def forward(self, data, input_def, mode, input_tensor=None):
if input_tensor is None:
self.inputs = self.build_inputs(data, input_def)
else:
self.inputs = input_tensor
self.inputs['mode'] = mode
self.model_arch()
if mode == 'train':
out = self.get_loss()
elif mode == 'infer':
out = self.get_pred()
out = self.get_pred(input_tensor is None)
else:
raise "Now, only support train or infer mode!"
out = None
raise "Now, only support train and infer mode!"
return out
def build_inputs(self, data, input_def):
......@@ -43,3 +47,6 @@ class BaseArch(nn.Layer):
def get_pred(self, ):
raise NotImplementedError("Should implement get_pred method!")
def get_export_model(self, input_tensor):
return self.forward(None, None, 'infer', input_tensor)
......@@ -43,13 +43,16 @@ class YOLOv3(BaseArch):
loss = self.yolo_head.get_loss(self.yolo_head_outs, self.inputs)
return loss
def get_pred(self, ):
def get_pred(self, return_numpy=True):
bbox, bbox_num = self.post_process(
self.yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
if return_numpy:
outs = {
"bbox": bbox.numpy(),
"bbox_num": bbox_num.numpy(),
'im_id': self.inputs['im_id'].numpy()
}
else:
outs = [bbox, bbox_num]
return outs
......@@ -358,7 +358,8 @@ class MultiClassNMS(object):
nms_threshold=.5,
normalized=False,
nms_eta=1.0,
background_label=0):
background_label=0,
return_rois_num=True):
super(MultiClassNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_top_k = nms_top_k
......@@ -367,6 +368,7 @@ class MultiClassNMS(object):
self.normalized = normalized
self.nms_eta = nms_eta
self.background_label = background_label
self.return_rois_num = return_rois_num
def __call__(self, bboxes, score):
kwargs = self.__dict__.copy()
......@@ -419,14 +421,10 @@ class YOLOBox(object):
self.clip_bbox = clip_bbox
self.scale_x_y = scale_x_y
def __call__(self, yolo_head_out, anchors, im_shape, scale_factor=None):
def __call__(self, yolo_head_out, anchors, im_shape, scale_factor):
boxes_list = []
scores_list = []
if scale_factor is not None:
origin_shape = im_shape / scale_factor
else:
origin_shape = im_shape
origin_shape = paddle.cast(origin_shape, 'int32')
for i, head_out in enumerate(yolo_head_out):
boxes, scores = ops.yolo_box(head_out, origin_shape, anchors[i],
......
......@@ -29,10 +29,19 @@ import numpy as np
from functools import reduce
__all__ = [
'roi_pool', 'roi_align', 'prior_box', 'anchor_generator',
'generate_proposals', 'iou_similarity', 'box_coder', 'yolo_box',
'multiclass_nms', 'distribute_fpn_proposals', 'collect_fpn_proposals',
'matrix_nms', 'BatchNorm'
'roi_pool',
'roi_align',
'prior_box',
'anchor_generator',
'generate_proposals',
'iou_similarity',
'box_coder',
'yolo_box',
'multiclass_nms',
'distribute_fpn_proposals',
'collect_fpn_proposals',
'matrix_nms',
'BatchNorm',
]
......@@ -663,7 +672,7 @@ def yolo_box(
clip_bbox, 'scale_x_y', scale_x_y)
boxes, scores = core.ops.yolo_box(x, origin_shape, *attrs)
return boxes, scores
else:
boxes = helper.create_variable_for_type_inference(dtype=x.dtype)
scores = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -952,6 +961,7 @@ def multiclass_nms(bboxes,
nms_eta=1.,
background_label=0,
return_index=False,
return_rois_num=True,
rois_num=None,
name=None):
"""
......@@ -1054,10 +1064,10 @@ def multiclass_nms(bboxes,
output, index, nms_rois_num = core.ops.multiclass_nms3(bboxes, scores,
rois_num, *attrs)
if return_index:
return output, index, nms_rois_num
else:
return output, nms_rois_num
index = None
return output, nms_rois_num, index
else:
output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
index = helper.create_variable_for_type_inference(dtype='int')
......@@ -1066,7 +1076,10 @@ def multiclass_nms(bboxes,
if rois_num is not None:
inputs['RoisNum'] = rois_num
nms_rois_num = helper.create_variable_for_type_inference(dtype='int32')
if return_rois_num:
nms_rois_num = helper.create_variable_for_type_inference(
dtype='int32')
outputs['NmsRoisNum'] = nms_rois_num
helper.append_op(
......@@ -1084,14 +1097,12 @@ def multiclass_nms(bboxes,
outputs=outputs)
output.stop_gradient = True
index.stop_gradient = True
if not return_index:
index = None
if not return_rois_num:
nms_rois_num = None
if return_index and rois_num is not None:
return output, index, nms_rois_num
elif return_index and rois_num is None:
return output, index
elif not return_index and rois_num is not None:
return output, nms_rois_num
return output
return output, nms_rois_num, index
def matrix_nms(bboxes,
......
......@@ -18,7 +18,7 @@ class BBoxPostProcess(object):
def __call__(self, head_out, rois, im_shape, scale_factor=None):
bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
bbox_pred, bbox_num = self.nms(bboxes, score)
bbox_pred, bbox_num, _ = self.nms(bboxes, score)
return bbox_pred, bbox_num
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, sys
# add python path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
sys.path.append(parent_path)
# ignore numba warning
import warnings
warnings.filterwarnings('ignore')
import glob
import numpy as np
from PIL import Image
import paddle
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.checkpoint import load_weight
from export_utils import dump_infer_config
from paddle.jit import to_static
import paddle.nn as nn
from paddle.static import InputSpec
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def parse_args():
parser = ArgsParser()
parser.add_argument(
"--output_dir",
type=str,
default="output_inference",
help="Directory for storing the output model files.")
args = parser.parse_args()
return args
def run(FLAGS, cfg):
# Model
main_arch = cfg.architecture
model = create(cfg.architecture)
inputs_def = cfg['TestReader']['inputs_def']
assert 'image_shape' in inputs_def, 'image_shape must be specified.'
image_shape = inputs_def.get('image_shape')
assert not None in image_shape, 'image_shape should not contain None'
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(FLAGS.output_dir, cfg_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
image_shape = dump_infer_config(cfg,
os.path.join(save_dir, 'infer_cfg.yml'),
image_shape)
class ExportModel(nn.Layer):
def __init__(self, model):
super(ExportModel, self).__init__()
self.model = model
@to_static(input_spec=[
{
'image': InputSpec(
shape=[None] + image_shape, name='image')
},
{
'im_shape': InputSpec(
shape=[None, 2], name='im_shape')
},
{
'scale_factor': InputSpec(
shape=[None, 2], name='scale_factor')
},
])
def forward(self, image, im_shape, scale_factor):
inputs = {}
inputs_tensor = [image, im_shape, scale_factor]
for t in inputs_tensor:
inputs.update(t)
outs = self.model.get_export_model(inputs)
return outs
export_model = ExportModel(model)
# debug for dy2static, remove later
#paddle.jit.set_code_level()
# Init Model
load_weight(export_model.model, cfg.weights)
export_model.eval()
# export config and model
paddle.jit.save(export_model, os.path.join(save_dir, 'model'))
logger.info('Export model to {}'.format(save_dir))
def main():
paddle.set_device("cpu")
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
merge_config(FLAGS.opt)
check_config(cfg)
check_gpu(cfg.use_gpu)
check_version()
run(FLAGS, cfg)
if __name__ == '__main__':
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import yaml
import numpy as np
from collections import OrderedDict
import logging
logger = logging.getLogger(__name__)
import paddle.fluid as fluid
__all__ = ['dump_infer_config', 'save_infer_model']
# Global dictionary
TRT_MIN_SUBGRAPH = {
'YOLO': 3,
'SSD': 3,
'RCNN': 40,
'RetinaNet': 40,
'EfficientDet': 40,
'Face': 3,
'TTFNet': 3,
'FCOS': 3,
'SOLOv2': 60,
}
def parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
preprocess_list = []
anno_file = dataset_cfg.get_anno()
with_background = reader_cfg['with_background']
use_default_label = dataset_cfg.use_default_label
if metric == 'COCO':
from ppdet.utils.coco_eval import get_category_info
else:
raise ValueError("metric only supports COCO, but received {}".format(
metric))
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
label_list = [str(cat) for cat in catid2name.values()]
sample_transforms = reader_cfg['sample_transforms']
for st in sample_transforms[1:]:
for key, value in st.items():
p = {'type': key}
if key == 'ResizeOp':
if value.get('keep_ratio', False):
max_size = max(image_shape[1:])
image_shape = [3, max_size, max_size]
p.update(value)
preprocess_list.append(p)
batch_transforms = reader_cfg.get('batch_transforms', None)
if batch_transforms:
methods = [list(bt.keys())[0] for bt in batch_transforms]
for bt in batch_transforms:
for key, value in bt.items():
if key == 'PadBatch':
preprocess_list.append({'type': 'PadStride'})
preprocess_list[-1].update({
'stride': value['pad_to_stride']
})
break
return with_background, preprocess_list, label_list, image_shape
def dump_infer_config(config, path, image_shape):
arch_state = False
from ppdet.core.config.yaml_helpers import setup_orderdict
setup_orderdict()
infer_cfg = OrderedDict({
'mode': 'fluid',
'draw_threshold': 0.5,
'metric': config['metric'],
'image_shape': image_shape
})
infer_arch = config['architecture']
for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items():
if arch in infer_arch:
infer_cfg['arch'] = arch
infer_cfg['min_subgraph_size'] = min_subgraph_size
arch_state = True
break
if not arch_state:
logger.error(
'Architecture: {} is not supported for exporting model now'.format(
infer_arch))
os._exit(0)
if 'Mask' in config['architecture']:
infer_cfg['mask_resolution'] = config['Mask']['mask_resolution']
infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[
'label_list'], image_shape = parse_reader(
config['TestReader'], config['TestDataset'], config['metric'],
infer_cfg['arch'], image_shape)
yaml.dump(infer_cfg, open(path, 'w'))
logger.info("Export inference config file to {}".format(os.path.join(path)))
return image_shape
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册