未验证 提交 5ad431bc 编写于 作者: B Bin Long 提交者: GitHub

Merge pull request #83 from joey12300/release/v0.2.0

Modify inference lib link
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <utils/utils.h>
#include <predictor/seg_predictor.h>
......@@ -9,7 +23,8 @@ int main(int argc, char** argv) {
// 0. parse args
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_conf.empty() || FLAGS_input_dir.empty()) {
std::cout << "Usage: ./predictor --conf=/config/path/to/your/model --input_dir=/directory/of/your/input/images";
std::cout << "Usage: ./predictor --conf=/config/path/to/your/model "
<< "--input_dir=/directory/of/your/input/images";
return -1;
}
// 1. create a predictor and init it with conf
......@@ -20,7 +35,8 @@ int main(int argc, char** argv) {
}
// 2. get all the images with extension '.jpeg' at input_dir
auto imgs = PaddleSolution::utils::get_directory_images(FLAGS_input_dir, ".jpeg|.jpg");
auto imgs = PaddleSolution::utils::get_directory_images(FLAGS_input_dir,
".jpeg|.jpg");
// 3. predict
predictor.predict(imgs);
return 0;
......
......@@ -6,7 +6,8 @@
## 前置条件
* G++ 4.8.2 ~ 4.9.4
* CMake 3.0+
* CUDA 8.0 / CUDA 9.0 / CUDA 10.0, cudnn 7+ (仅在使用GPU版本的预测库时需要)
* CUDA 9.0 / CUDA 10.0, cudnn 7+ (仅在使用GPU版本的预测库时需要)
* CentOS 7.6, Ubuntu 16.04, Ubuntu 18.04 (均在以上系统验证过)
请确保系统已经安装好上述基本软件,**下面所有示例以工作目录为 `/root/projects/`演示**
......@@ -20,17 +21,16 @@
### Step2: 下载PaddlePaddle C++ 预测库 fluid_inference
PaddlePaddle C++ 预测库主要分为CPU版本和GPU版本。其中,针对不同的CUDA版本,GPU版本预测库又分为三个版本预测库:CUDA 8、CUDA 9和CUDA 10版本预测库。以下为各版本C++预测库的下载链接:
PaddlePaddle C++ 预测库主要分为CPU版本和GPU版本。其中,针对不同的CUDA版本,GPU版本预测库又分为两个版本预测库:CUDA 9.0和CUDA 10.0版本预测库。以下为各版本C++预测库的下载链接:
| 版本 | 链接 |
| ---- | ---- |
| CPU版本 | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/latest-cpu-avx-mkl/fluid_inference.tgz) |
| CUDA 8版本 | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/latest-gpu-cuda8-cudnn7-avx-mkl/fluid_inference.tgz) |
| CUDA 9版本 | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/latest-gpu-cuda9-cudnn7-avx-mkl/fluid_inference.tgz) |
| CUDA 10版本 | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/latest-gpu-cuda10-cudnn7-avx-mkl/fluid_inference.tgz) |
| CPU版本 | [fluid_inference.tgz](https://bj.bcebos.com/paddlehub/paddle_inference_lib/fluid_inference_linux_cpu_1.6.1.tgz) |
| CUDA 9.0版本 | [fluid_inference.tgz](https://bj.bcebos.com/paddlehub/paddle_inference_lib/fluid_inference_linux_cuda97_1.6.1.tgz) |
| CUDA 10.0版本 | [fluid_inference.tgz](https://bj.bcebos.com/paddlehub/paddle_inference_lib/fluid_inference_linux_cuda10_1.6.1.tgz) |
针对不同的CPU类型、不同的指令集,官方提供更多可用的预测库版本,目前已经推出1.6版本的预测库具体请参考以下链接:[C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_usage/deploy/inference/build_and_install_lib_cn.html)
针对不同的CPU类型、不同的指令集,官方提供更多可用的预测库版本,目前已经推出1.6版本的预测库。其余版本具体请参考以下链接:[C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_usage/deploy/inference/build_and_install_lib_cn.html)
下载并解压后`/root/projects/fluid_inference`目录包含内容为:
......@@ -63,7 +63,7 @@ make install
### Step4: 编译
`CMake`编译时,涉及到四个编译参数用于指定核心依赖库的路径, 他们的定义如下:(带*表示仅在使用**GPU版本**预测库时指定)
`CMake`编译时,涉及到四个编译参数用于指定核心依赖库的路径, 他们的定义如下:(带*表示仅在使用**GPU版本**预测库时指定,其中CUDA库版本尽量对齐,**使用9.0、10.0版本,不使用9.2、10.1版本CUDA库**
| 参数名 | 含义 |
| ---- | ---- |
......@@ -84,6 +84,7 @@ make
在使用**CPU版本**预测库进行编译时,可执行下列操作。
```shell
cd /root/projects/PaddleSeg/inference
mkdir build && cd build
cmake .. -DWITH_GPU=OFF -DPADDLE_DIR=/root/projects/fluid_inference -DOPENCV_DIR=/root/projects/opencv3/ -DWITH_STATIC_LIB=OFF
make
......
......@@ -5,7 +5,7 @@
## 前置条件
* Visual Studio 2015
* CUDA 8.0/ CUDA 9.0/ CUDA 10.0,cudnn 7+ (仅在使用GPU版本的预测库时需要)
* CUDA 9.0 / CUDA 10.0,cudnn 7+ (仅在使用GPU版本的预测库时需要)
* CMake 3.0+
请确保系统已经安装好上述基本软件,**下面所有示例以工作目录为 `D:\projects`演示**
......@@ -20,14 +20,13 @@
### Step2: 下载PaddlePaddle C++ 预测库 fluid_inference
PaddlePaddle C++ 预测库主要分为两大版本:CPU版本和GPU版本。其中,针对不同的CUDA版本,GPU版本预测库又分为三个版本预测库:CUDA 8、CUDA 9和CUDA 10版本预测库。根据Windows环境,下载相应版本的PaddlePaddle预测库,并解压到`D:\projects\`目录。以下为各版本C++预测库(CUDA 8版本基于1.5版本的预测库,其余均基于1.6版本的预测库)的下载链接:
PaddlePaddle C++ 预测库主要分为两大版本:CPU版本和GPU版本。其中,针对不同的CUDA版本,GPU版本预测库又分为两个版本预测库:CUDA 9.0和CUDA 10.0版本预测库。根据Windows环境,下载相应版本的PaddlePaddle预测库,并解压到`D:\projects\`目录。以下为各版本C++预测库的下载链接:
| 版本 | 链接 |
| ---- | ---- |
| CPU版本 | [fluid_inference_install_dir.zip](https://paddle-wheel.bj.bcebos.com/1.6.0/win-infer/mkl/cpu/fluid_inference_install_dir.zip) |
| CUDA 8版本 | [fluid_inference_install_dir.zip](https://paddle-inference-lib.bj.bcebos.com/1.5.1-win/gpu_mkl_avx_8.0/fluid_inference_install_dir.zip) |
| CUDA 9版本 | [fluid_inference_install_dir.zip](https://paddle-wheel.bj.bcebos.com/1.6.0/win-infer/mkl/post97/fluid_inference_install_dir.zip) |
| CUDA 10版本 | [fluid_inference_install_dir.zip](https://paddle-wheel.bj.bcebos.com/1.6.0/win-infer/mkl/post107/fluid_inference_install_dir.zip) |
| CPU版本 | [fluid_inference_install_dir.zip](https://bj.bcebos.com/paddlehub/paddle_inference_lib/fluid_install_dir_win_cpu_1.6.zip) |
| CUDA 9.0版本 | [fluid_inference_install_dir.zip](https://bj.bcebos.com/paddlehub/paddle_inference_lib/fluid_inference_install_dir_win_cuda9_1.6.1.zip) |
| CUDA 10.0版本 | [fluid_inference_install_dir.zip](https://bj.bcebos.com/paddlehub/paddle_inference_lib/fluid_inference_install_dir_win_cuda10_1.6.1.zip) |
解压后`D:\projects\fluid_inference`目录包含内容为:
```
......@@ -59,31 +58,36 @@ fluid_inference
call "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\vcvarsall.bat" amd64
```
* CMAKE编译工程 (带*表示仅在使用**GPU版本**预测库时指定)
* PADDLE_DIR: fluid_inference预测库路径
* *CUDA_LIB: CUDA动态库目录, 请根据实际安装情况调整
* OPENCV_DIR: OpenCV解压目录
三个编译参数的含义说明如下(带*表示仅在使用**GPU版本**预测库时指定, 其中CUDA库版本尽量对齐,**使用9.0、10.0版本,不使用9.2、10.1等版本CUDA库**):
在使用**GPU版本**预测库进行编译时,可执行下列操作。
```
| 参数名 | 含义 |
| ---- | ---- |
| *CUDA_LIB | CUDA的库路径 |
| OPENCV_DIR | OpenCV的安装路径 |
| PADDLE_DIR | Paddle预测库的路径 |
在使用**GPU版本**预测库进行编译时,可执行下列操作。**注意**把对应的参数改为你的上述依赖库实际路径:
```bash
# 切换到预测库所在目录
cd /d D:\projects\PaddleSeg\inference\
# 创建构建目录, 重新构建只需要删除该目录即可
mkdir build
cd build
# cmake构建VS项目
D:\projects\PaddleSeg\inference\build> cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_GPU=ON -DPADDLE_DIR=D:\projects\fluid_inference -DCUDA_LIB=D:\projects\cudalib\v8.0\lib\x64 -DOPENCV_DIR=D:\projects\opencv -T host=x64
D:\projects\PaddleSeg\inference\build> cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_GPU=ON -DPADDLE_DIR=D:\projects\fluid_inference -DCUDA_LIB=D:\projects\cudalib\v9.0\lib\x64 -DOPENCV_DIR=D:\projects\opencv -T host=x64
```
在使用**CPU版本**预测库进行编译时,可执行下列操作。
```
```bash
# 切换到预测库所在目录
cd /d D:\projects\PaddleSeg\inference\
# 创建构建目录, 重新构建只需要删除该目录即可
mkdir build
cd build
# cmake构建VS项目
D:\projects\PaddleSeg\inference\build> cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_GPU=ON -DPADDLE_DIR=D:\projects\fluid_inference -DOPENCV_DIR=D:\projects\opencv -T host=x64
D:\projects\PaddleSeg\inference\build> cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_GPU=OFF -DPADDLE_DIR=D:\projects\fluid_inference -DOPENCV_DIR=D:\projects\opencv -T host=x64
```
这里的`cmake`参数`-G`, 表示生成对应的VS版本的工程,可以根据自己的`VS`版本调整,具体请参考[cmake文档](https://cmake.org/cmake/help/v3.15/manual/cmake-generators.7.html)
......
......@@ -6,7 +6,7 @@ Windows 平台下,我们使用`Visual Studio 2015` 和 `Visual Studio 2019 Com
## 前置条件
* Visual Studio 2019
* CUDA 8.0/ CUDA 9.0/ CUDA 10.0,cudnn 7+ (仅在使用GPU版本的预测库时需要)
* CUDA 9.0/ CUDA 10.0,cudnn 7+ (仅在使用GPU版本的预测库时需要)
* CMake 3.0+
请确保系统已经安装好上述基本软件,我们使用的是`VS2019`的社区版。
......@@ -15,7 +15,7 @@ Windows 平台下,我们使用`Visual Studio 2015` 和 `Visual Studio 2019 Com
### Step1: 下载代码
1. 点击下载源代码:[下载地址](https://github.com/PaddlePaddle/PaddleSeg/archive/master.zip)
1. 点击下载源代码:[下载地址](https://github.com/PaddlePaddle/PaddleSeg/archive/release/v0.2.0.zip)
2. 解压,解压后目录重命名为`PaddleSeg`
以下代码目录路径为`D:\projects\PaddleSeg` 为例。
......@@ -23,14 +23,13 @@ Windows 平台下,我们使用`Visual Studio 2015` 和 `Visual Studio 2019 Com
### Step2: 下载PaddlePaddle C++ 预测库 fluid_inference
PaddlePaddle C++ 预测库主要分为两大版本:CPU版本和GPU版本。其中,针对不同的CUDA版本,GPU版本预测库又分为三个版本预测库:CUDA 8、CUDA 9和CUDA 10版本预测库。根据Windows环境,下载相应版本的PaddlePaddle预测库,并解压到`D:\projects\`目录。以下为各版本C++预测库(CUDA 8版本基于1.5版本的预测库,其余均基于1.6版本的预测库)的下载链接:
PaddlePaddle C++ 预测库主要分为两大版本:CPU版本和GPU版本。其中,针对不同的CUDA版本,GPU版本预测库又分为三个版本预测库:CUDA 9.0和CUDA 10.0版本预测库。根据Windows环境,下载相应版本的PaddlePaddle预测库,并解压到`D:\projects\`目录。以下为各版本C++预测库的下载链接:
| 版本 | 链接 |
| ---- | ---- |
| CPU版本 | [fluid_inference_install_dir.zip](https://paddle-wheel.bj.bcebos.com/1.6.0/win-infer/mkl/cpu/fluid_inference_install_dir.zip) |
| CUDA 8版本 | [fluid_inference_install_dir.zip](https://paddle-inference-lib.bj.bcebos.com/1.5.1-win/gpu_mkl_avx_8.0/fluid_inference_install_dir.zip) |
| CUDA 9版本 | [fluid_inference_install_dir.zip](https://paddle-wheel.bj.bcebos.com/1.6.0/win-infer/mkl/post97/fluid_inference_install_dir.zip) |
| CUDA 10版本 | [fluid_inference_install_dir.zip](https://paddle-wheel.bj.bcebos.com/1.6.0/win-infer/mkl/post107/fluid_inference_install_dir.zip) |
| CPU版本 | [fluid_inference_install_dir.zip](https://bj.bcebos.com/paddlehub/paddle_inference_lib/fluid_install_dir_win_cpu_1.6.zip) |
| CUDA 9.0版本 | [fluid_inference_install_dir.zip](https://bj.bcebos.com/paddlehub/paddle_inference_lib/fluid_inference_install_dir_win_cuda9_1.6.1.zip) |
| CUDA 10.0版本 | [fluid_inference_install_dir.zip](https://bj.bcebos.com/paddlehub/paddle_inference_lib/fluid_inference_install_dir_win_cuda10_1.6.1.zip) |
解压后`D:\projects\fluid_inference`目录包含内容为:
```
......@@ -68,12 +67,12 @@ fluid_inference
4. 点击`浏览`,分别设置编译选项指定`CUDA`、`OpenCV`、`Paddle预测库`的路径
三个编译参数的含义说明如下(带*表示仅在使用**GPU版本**预测库时指定):
三个编译参数的含义说明如下(带*表示仅在使用**GPU版本**预测库时指定, 其中CUDA库版本尽量对齐,**使用9.0、10.0版本,不使用9.2、10.1等版本CUDA库**):
| 参数名 | 含义 |
| ---- | ---- |
| *CUDA_LIB | cuda的库路径 |
| OPENCV_DIR | OpenCV的安装路径 |
| *CUDA_LIB | CUDA的库路径 |
| OPENCV_DIR | OpenCV的安装路径 |
| PADDLE_DIR | Paddle预测库的路径 |
**注意**在使用CPU版本预测库时,需要把CUDA_LIB的勾去掉。
![step4](https://paddleseg.bj.bcebos.com/inference/vs2019_step5.png)
......@@ -90,7 +89,7 @@ fluid_inference
上述`Visual Studio 2019`编译产出的可执行文件在`out\build\x64-Release`目录下,打开`cmd`,并切换到该目录:
```
cd /d D:\projects\PaddleSeg\inference\out\x64-Release
cd /d D:\projects\PaddleSeg\inference\out\build\x64-Release
```
之后执行命令:
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "seg_predictor.h"
#include <unsupported/Eigen/CXX11/Tensor>
#undef min
namespace PaddleSolution {
using std::chrono::duration_cast;
int Predictor::init(const std::string& conf) {
if (!_model_config.load_config(conf)) {
LOG(FATAL) << "Fail to load config file: [" << conf << "]";
......@@ -14,8 +28,9 @@ namespace PaddleSolution {
return -1;
}
_mask.resize(_model_config._resize[0] * _model_config._resize[1]);
_scoremap.resize(_model_config._resize[0] * _model_config._resize[1]);
int res_size = _model_config._resize[0] * _model_config._resize[1];
_mask.resize(res_size);
_scoremap.resize(res_size);
bool use_gpu = _model_config._use_gpu;
const auto& model_dir = _model_config._model_path;
......@@ -33,8 +48,7 @@ namespace PaddleSolution {
config.use_gpu = use_gpu;
config.device = 0;
_main_predictor = paddle::CreatePaddlePredictor(config);
}
else if (_model_config._predictor_mode == "ANALYSIS") {
} else if (_model_config._predictor_mode == "ANALYSIS") {
paddle::AnalysisConfig config;
if (use_gpu) {
config.EnableUseGpu(100, 0);
......@@ -46,25 +60,23 @@ namespace PaddleSolution {
config.SwitchSpecifyInputNames(true);
config.EnableMemoryOptim();
_main_predictor = paddle::CreatePaddlePredictor(config);
}
else {
} else {
return -1;
}
return 0;
}
int Predictor::predict(const std::vector<std::string>& imgs) {
if (_model_config._predictor_mode == "NATIVE") {
return native_predict(imgs);
}
else if (_model_config._predictor_mode == "ANALYSIS") {
} else if (_model_config._predictor_mode == "ANALYSIS") {
return analysis_predict(imgs);
}
return -1;
}
int Predictor::output_mask(const std::string& fname, float* p_out, int length, int* height, int* width) {
int Predictor::output_mask(const std::string& fname, float* p_out,
int length, int* height, int* width) {
int eval_width = _model_config._resize[0];
int eval_height = _model_config._resize[1];
int eval_num_class = _model_config._class_num;
......@@ -77,8 +89,7 @@ namespace PaddleSolution {
seg_out_len << "|" << blob_out_len << "]" << std::endl;
return -1;
}
//post process
// post process
_mask.clear();
_scoremap.clear();
std::vector<int> out_shape{eval_num_class, eval_height, eval_width};
......@@ -92,15 +103,18 @@ namespace PaddleSolution {
cv::imwrite(mask_save_name, mask_png);
cv::Mat scoremap_png = cv::Mat(eval_height, eval_width, CV_8UC1);
scoremap_png.data = _scoremap.data();
std::string scoremap_save_name = nname + std::string("_scoremap.png");
std::string scoremap_save_name = nname
+ std::string("_scoremap.png");
cv::imwrite(scoremap_save_name, scoremap_png);
std::cout << "save mask of [" << fname << "] done" << std::endl;
if (height && width) {
int recover_height = *height;
int recover_width = *width;
cv::Mat recover_png = cv::Mat(recover_height, recover_width, CV_8UC1);
cv::resize(scoremap_png, recover_png, cv::Size(recover_width, recover_height),
cv::Mat recover_png = cv::Mat(recover_height,
recover_width, CV_8UC1);
cv::resize(scoremap_png, recover_png,
cv::Size(recover_width, recover_height),
0, 0, cv::INTER_CUBIC);
std::string recover_name = nname + std::string("_recover.png");
cv::imwrite(recover_name, recover_png);
......@@ -108,8 +122,7 @@ namespace PaddleSolution {
return 0;
}
int Predictor::native_predict(const std::vector<std::string>& imgs)
{
int Predictor::native_predict(const std::vector<std::string>& imgs) {
if (imgs.size() == 0) {
LOG(ERROR) << "No image found";
return -1;
......@@ -120,9 +133,12 @@ namespace PaddleSolution {
int eval_width = _model_config._resize[0];
int eval_height = _model_config._resize[1];
std::size_t total_size = imgs.size();
int default_batch_size = std::min(config_batch_size, (int)total_size);
int batch = total_size / default_batch_size + ((total_size % default_batch_size) != 0);
int batch_buffer_size = default_batch_size * channels * eval_width * eval_height;
int default_batch_size = std::min(config_batch_size,
static_cast<int>(total_size));
int batch = total_size / default_batch_size
+ ((total_size % default_batch_size) != 0);
int batch_buffer_size = default_batch_size * channels
* eval_width * eval_height;
auto& input_buffer = _buffer;
auto& org_width = _org_width;
......@@ -138,7 +154,8 @@ namespace PaddleSolution {
batch_size = total_size % default_batch_size;
}
int real_buffer_size = batch_size * channels * eval_width * eval_height;
int real_buffer_size = batch_size * channels
* eval_width * eval_height;
std::vector<paddle::PaddleTensor> feeds;
input_buffer.resize(real_buffer_size);
org_height.resize(batch_size);
......@@ -151,23 +168,31 @@ namespace PaddleSolution {
int idx = u * default_batch_size + i;
imgs_batch.push_back(imgs[idx]);
}
if (!_preprocessor->batch_process(imgs_batch, input_buffer.data(), org_width.data(), org_height.data())) {
if (!_preprocessor->batch_process(imgs_batch,
input_buffer.data(),
org_width.data(),
org_height.data())) {
return -1;
}
paddle::PaddleTensor im_tensor;
im_tensor.name = "image";
im_tensor.shape = std::vector<int>({ batch_size, channels, eval_height, eval_width });
im_tensor.data.Reset(input_buffer.data(), real_buffer_size * sizeof(float));
im_tensor.shape = std::vector<int>{ batch_size, channels,
eval_height, eval_width };
im_tensor.data.Reset(input_buffer.data(),
real_buffer_size * sizeof(float));
im_tensor.dtype = paddle::PaddleDType::FLOAT32;
feeds.push_back(im_tensor);
_outputs.clear();
auto t1 = std::chrono::high_resolution_clock::now();
if (!_main_predictor->Run(feeds, &_outputs, batch_size)) {
LOG(ERROR) << "Failed: NativePredictor->Run() return false at batch: " << u;
LOG(ERROR) <<
"Failed: NativePredictor->Run() return false at batch: "
<< u;
continue;
}
auto t2 = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
auto duration = duration_cast<std::chrono::microseconds>
(t2 - t1).count();
std::cout << "runtime = " << duration << std::endl;
int out_num = 1;
// print shape of first output tensor for debugging
......@@ -177,15 +202,21 @@ namespace PaddleSolution {
std::cout << _outputs[0].shape[j] << ",";
}
std::cout << ")" << std::endl;
const size_t nums = _outputs.front().data.length() / sizeof(float);
const size_t nums = _outputs.front().data.length()
/ sizeof(float);
if (out_num % batch_size != 0 || out_num != nums) {
LOG(ERROR) << "outputs data size mismatch with shape size.";
return -1;
}
for (int i = 0; i < batch_size; ++i) {
float* output_addr = (float*)(_outputs[0].data.data()) + i * (out_num / batch_size);
output_mask(imgs_batch[i], output_addr, out_num / batch_size, &org_height[i], &org_width[i]);
float* output_addr = reinterpret_cast<float*>(
_outputs[0].data.data())
+ i * (out_num / batch_size);
output_mask(imgs_batch[i], output_addr,
out_num / batch_size,
&org_height[i],
&org_width[i]);
}
}
......@@ -193,7 +224,6 @@ namespace PaddleSolution {
}
int Predictor::analysis_predict(const std::vector<std::string>& imgs) {
if (imgs.size() == 0) {
LOG(ERROR) << "No image found";
return -1;
......@@ -204,9 +234,12 @@ namespace PaddleSolution {
int eval_width = _model_config._resize[0];
int eval_height = _model_config._resize[1];
auto total_size = imgs.size();
int default_batch_size = std::min(config_batch_size, (int)total_size);
int batch = total_size / default_batch_size + ((total_size % default_batch_size) != 0);
int batch_buffer_size = default_batch_size * channels * eval_width * eval_height;
int default_batch_size = std::min(config_batch_size,
static_cast<int>(total_size));
int batch = total_size / default_batch_size
+ ((total_size % default_batch_size) != 0);
int batch_buffer_size = default_batch_size * channels
* eval_width * eval_height;
auto& input_buffer = _buffer;
auto& org_width = _org_width;
......@@ -223,7 +256,8 @@ namespace PaddleSolution {
batch_size = total_size % default_batch_size;
}
int real_buffer_size = batch_size * channels * eval_width * eval_height;
int real_buffer_size = batch_size * channels
* eval_width * eval_height;
std::vector<paddle::PaddleTensor> feeds;
input_buffer.resize(real_buffer_size);
org_height.resize(batch_size);
......@@ -237,21 +271,27 @@ namespace PaddleSolution {
imgs_batch.push_back(imgs[idx]);
}
if (!_preprocessor->batch_process(imgs_batch, input_buffer.data(), org_width.data(), org_height.data())) {
if (!_preprocessor->batch_process(imgs_batch,
input_buffer.data(),
org_width.data(),
org_height.data())) {
return -1;
}
auto im_tensor = _main_predictor->GetInputTensor("image");
im_tensor->Reshape({ batch_size, channels, eval_height, eval_width });
im_tensor->Reshape({ batch_size, channels,
eval_height, eval_width });
im_tensor->copy_from_cpu(input_buffer.data());
auto t1 = std::chrono::high_resolution_clock::now();
_main_predictor->ZeroCopyRun();
auto t2 = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
auto duration = duration_cast<std::chrono::microseconds>
(t2 - t1).count();
std::cout << "runtime = " << duration << std::endl;
auto output_names = _main_predictor->GetOutputNames();
auto output_t = _main_predictor->GetOutputTensor(output_names[0]);
auto output_t = _main_predictor->GetOutputTensor(
output_names[0]);
std::vector<float> out_data;
std::vector<int> output_shape = output_t->shape();
......@@ -266,10 +306,12 @@ namespace PaddleSolution {
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
for (int i = 0; i < batch_size; ++i) {
float* out_addr = out_data.data() + (out_num / batch_size) * i;
output_mask(imgs_batch[i], out_addr, out_num / batch_size, &org_height[i], &org_width[i]);
float* out_addr = out_data.data()
+ (out_num / batch_size) * i;
output_mask(imgs_batch[i], out_addr, out_num / batch_size,
&org_height[i], &org_width[i]);
}
}
return 0;
}
}
} // namespace PaddleSolution
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <yaml-cpp/yaml.h>
#include <memory>
#include <string>
#include <vector>
#include <thread>
#include <chrono>
#include <algorithm>
#include <glog/logging.h>
#include <yaml-cpp/yaml.h>
#include <opencv2/opencv.hpp>
#include <paddle_inference_api.h>
#include <utils/seg_conf_parser.h>
#include <utils/utils.h>
#include <preprocessor/preprocessor.h>
#include <paddle_inference_api.h>
#include <opencv2/opencv.hpp>
#include "utils/seg_conf_parser.h"
#include "utils/utils.h"
#include "preprocessor/preprocessor.h"
namespace PaddleSolution {
class Predictor {
class Predictor {
public:
// init a predictor with a yaml config file
int init(const std::string& conf);
// predict api
int predict(const std::vector<std::string>& imgs);
private:
int output_mask(
const std::string& fname,
float* p_out,
int length,
int* height = NULL,
int* width = NULL);
int output_mask(const std::string& fname, float* p_out, int length,
int* height = NULL, int* width = NULL);
int native_predict(const std::vector<std::string>& imgs);
int analysis_predict(const std::vector<std::string>& imgs);
private:
......@@ -45,5 +55,5 @@ namespace PaddleSolution {
PaddleSolution::PaddleSegModelConfigPaser _model_config;
std::shared_ptr<PaddleSolution::ImagePreProcessor> _preprocessor;
std::unique_ptr<paddle::PaddlePredictor> _main_predictor;
};
}
};
} // namespace PaddleSolution
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
......@@ -7,9 +21,10 @@
namespace PaddleSolution {
std::shared_ptr<ImagePreProcessor> create_processor(const std::string& conf_file) {
auto config = std::make_shared<PaddleSolution::PaddleSegModelConfigPaser>();
std::shared_ptr<ImagePreProcessor> create_processor(
const std::string& conf_file) {
auto config = std::make_shared<PaddleSolution::
PaddleSegModelConfigPaser>();
if (!config->load_config(conf_file)) {
LOG(FATAL) << "fail to laod conf file [" << conf_file << "]";
return nullptr;
......@@ -23,9 +38,9 @@ namespace PaddleSolution {
return p;
}
LOG(FATAL) << "unknown processor_name [" << config->_pre_processor << "]";
LOG(FATAL) << "unknown processor_name [" << config->_pre_processor
<< "]";
return nullptr;
}
}
} // namespace PaddleSolution
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include <string>
......@@ -12,18 +26,19 @@
namespace PaddleSolution {
class ImagePreProcessor {
protected:
ImagePreProcessor() {};
public:
protected:
ImagePreProcessor() {}
public:
virtual ~ImagePreProcessor() {}
virtual bool single_process(const std::string& fname, float* data, int* ori_w, int* ori_h) = 0;
virtual bool batch_process(const std::vector<std::string>& imgs, float* data, int* ori_w, int* ori_h) = 0;
virtual bool single_process(const std::string& fname, float* data,
int* ori_w, int* ori_h) = 0;
virtual bool batch_process(const std::vector<std::string>& imgs,
float* data, int* ori_w, int* ori_h) = 0;
}; // end of class ImagePreProcessor
std::shared_ptr<ImagePreProcessor> create_processor(const std::string &config_file);
std::shared_ptr<ImagePreProcessor> create_processor(
const std::string &config_file);
} // end of namespace paddle_solution
} // namespace PaddleSolution
#include <thread>
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "preprocessor_seg.h"
#include <glog/logging.h>
#include "preprocessor_seg.h"
#include <thread>
namespace PaddleSolution {
bool SegPreProcessor::single_process(const std::string& fname, float* data, int* ori_w, int* ori_h) {
bool SegPreProcessor::single_process(const std::string& fname,
float* data, int* ori_w, int* ori_h) {
cv::Mat im = cv::imread(fname, -1);
if (im.data == nullptr || im.empty()) {
LOG(ERROR) << "Failed to open image: " << fname;
return false;
}
int channels = im.channels();
*ori_w = im.cols;
*ori_h = im.rows;
......@@ -36,7 +51,8 @@ namespace PaddleSolution {
return true;
}
bool SegPreProcessor::batch_process(const std::vector<std::string>& imgs, float* data, int* ori_w, int* ori_h) {
bool SegPreProcessor::batch_process(const std::vector<std::string>& imgs,
float* data, int* ori_w, int* ori_h) {
auto ic = _config->_channels;
auto iw = _config->_resize[0];
auto ih = _config->_resize[1];
......@@ -58,9 +74,9 @@ namespace PaddleSolution {
return true;
}
bool SegPreProcessor::init(std::shared_ptr<PaddleSolution::PaddleSegModelConfigPaser> config) {
bool SegPreProcessor::init(
std::shared_ptr<PaddleSolution::PaddleSegModelConfigPaser> config) {
_config = config;
return true;
}
}
} // namespace PaddleSolution
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include <memory>
#include "preprocessor.h"
#include "utils/utils.h"
namespace PaddleSolution {
class SegPreProcessor : public ImagePreProcessor {
public:
SegPreProcessor() : _config(nullptr) {}
public:
SegPreProcessor() : _config(nullptr){
};
bool init(std::shared_ptr<PaddleSolution::PaddleSegModelConfigPaser> config);
bool init(
std::shared_ptr<PaddleSolution::PaddleSegModelConfigPaser> config);
bool single_process(const std::string &fname, float* data, int* ori_w, int* ori_h);
bool single_process(const std::string &fname, float* data,
int* ori_w, int* ori_h);
bool batch_process(const std::vector<std::string>& imgs, float* data, int* ori_w, int* ori_h);
private:
bool batch_process(const std::vector<std::string>& imgs, float* data,
int* ori_w, int* ori_h);
private:
std::shared_ptr<PaddleSolution::PaddleSegModelConfigPaser> _config;
};
}
} // namespace PaddleSolution
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import sys
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <yaml-cpp/yaml.h>
#include <iostream>
#include <vector>
#include <string>
#include <yaml-cpp/yaml.h>
namespace PaddleSolution {
class PaddleSegModelConfigPaser {
class PaddleSegModelConfigPaser {
public:
PaddleSegModelConfigPaser()
:_class_num(0),
......@@ -56,7 +69,6 @@ namespace PaddleSolution {
}
bool load_config(const std::string& conf_file) {
reset();
YAML::Node config = YAML::LoadFile(conf_file);
......@@ -83,7 +95,8 @@ namespace PaddleSolution {
// 8. get model file_name
_model_file_name = config["DEPLOY"]["MODEL_FILENAME"].as<std::string>();
// 9. get model param file name
_param_file_name = config["DEPLOY"]["PARAMS_FILENAME"].as<std::string>();
_param_file_name =
config["DEPLOY"]["PARAMS_FILENAME"].as<std::string>();
// 10. get pre_processor
_pre_processor = config["DEPLOY"]["PRE_PROCESSOR"].as<std::string>();
// 11. use_gpu
......@@ -98,9 +111,9 @@ namespace PaddleSolution {
}
void debug() const {
std::cout << "EVAL_CROP_SIZE: (" << _resize[0] << ", " << _resize[1] << ")" << std::endl;
std::cout << "EVAL_CROP_SIZE: ("
<< _resize[0] << ", " << _resize[1]
<< ")" << std::endl;
std::cout << "MEAN: [";
for (int i = 0; i < _mean.size(); ++i) {
if (i != _mean.size() - 1) {
......@@ -115,8 +128,7 @@ namespace PaddleSolution {
for (int i = 0; i < _std.size(); ++i) {
if (i != _std.size() - 1) {
std::cout << _std[i] << ", ";
}
else {
} else {
std::cout << _std[i];
}
}
......@@ -127,7 +139,8 @@ namespace PaddleSolution {
std::cout << "DEPLOY.CHANNELS: " << _channels << std::endl;
std::cout << "DEPLOY.MODEL_PATH: " << _model_path << std::endl;
std::cout << "DEPLOY.MODEL_FILENAME: " << _model_file_name << std::endl;
std::cout << "DEPLOY.PARAMS_FILENAME: " << _param_file_name << std::endl;
std::cout << "DEPLOY.PARAMS_FILENAME: "
<< _param_file_name << std::endl;
std::cout << "DEPLOY.PRE_PROCESSOR: " << _pre_processor << std::endl;
std::cout << "DEPLOY.USE_GPU: " << _use_gpu << std::endl;
std::cout << "DEPLOY.PREDICTOR_MODE: " << _predictor_mode << std::endl;
......@@ -160,6 +173,6 @@ namespace PaddleSolution {
std::string _predictor_mode;
// DEPLOY.BATCH_SIZE
int _batch_size;
};
};
}
} // namespace PaddleSolution
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
......@@ -16,8 +30,9 @@
#endif
namespace PaddleSolution {
namespace utils {
inline std::string path_join(const std::string& dir, const std::string& path) {
namespace utils {
inline std::string path_join(const std::string& dir,
const std::string& path) {
std::string seperator = "/";
#ifdef _WIN32
seperator = "\\";
......@@ -26,8 +41,8 @@ namespace PaddleSolution {
}
#ifndef _WIN32
// scan a directory and get all files with input extensions
inline std::vector<std::string> get_directory_images(const std::string& path, const std::string& exts)
{
inline std::vector<std::string> get_directory_images(
const std::string& path, const std::string& exts) {
std::vector<std::string> imgs;
struct dirent *entry;
DIR *dir = opendir(path.c_str());
......@@ -50,13 +65,15 @@ namespace PaddleSolution {
}
#else
// scan a directory and get all files with input extensions
inline std::vector<std::string> get_directory_images(const std::string& path, const std::string& exts)
{
inline std::vector<std::string> get_directory_images(
const std::string& path, const std::string& exts) {
std::vector<std::string> imgs;
for (const auto& item : std::experimental::filesystem::directory_iterator(path)) {
for (const auto& item :
std::experimental::filesystem::directory_iterator(path)) {
auto suffix = item.path().extension().string();
if (exts.find(suffix) != std::string::npos && suffix.size() > 0) {
auto fullname = path_join(path, item.path().filename().string());
auto fullname = path_join(path,
item.path().filename().string());
imgs.push_back(item.path().string());
}
}
......@@ -65,11 +82,12 @@ namespace PaddleSolution {
#endif
// normalize and HWC_BGR -> CHW_RGB
inline void normalize(cv::Mat& im, float* data, std::vector<float>& fmean, std::vector<float>& fstd) {
inline void normalize(cv::Mat& im, float* data, std::vector<float>& fmean,
std::vector<float>& fstd) {
int rh = im.rows;
int rw = im.cols;
int rc = im.channels();
double normf = (double)1.0 / 255.0;
double normf = static_cast<double>(1.0) / 255.0;
#pragma omp parallel for
for (int h = 0; h < rh; ++h) {
const uchar* ptr = im.ptr<uchar>(h);
......@@ -86,7 +104,8 @@ namespace PaddleSolution {
}
// argmax
inline void argmax(float* out, std::vector<int>& shape, std::vector<uchar>& mask, std::vector<uchar>& scoremap) {
inline void argmax(float* out, std::vector<int>& shape,
std::vector<uchar>& mask, std::vector<uchar>& scoremap) {
int out_img_len = shape[1] * shape[2];
int blob_out_len = out_img_len * shape[0];
/*
......@@ -116,5 +135,5 @@ namespace PaddleSolution {
scoremap[i] = uchar(max_value * 255);
}
}
}
}
} // namespace utils
} // namespace PaddleSolution
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册