Unverified commit ded24add, authored by W Walter, committed by GitHub

Merge pull request #1498 from RainFrost1/whole_chain_c++

Update the classification cpp infer input to a yaml file
@@ -14,6 +14,11 @@ SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
set(DEMO_NAME "clas_system")
include(external-cmake/yaml-cpp.cmake)
include_directories("${CMAKE_SOURCE_DIR}/")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/src/ext-yaml-cpp/include")
link_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/lib")
macro(safe_set_static_flag)
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
@@ -61,7 +66,7 @@ if (WIN32)
add_definitions(-DSTATIC_LIB)
endif()
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -o3 -std=c++11")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O3 -std=c++11")
set(CMAKE_STATIC_LIBRARY_PREFIX "")
endif()
message("flags" ${CMAKE_CXX_FLAGS})
@@ -153,7 +158,7 @@ endif(WITH_STATIC_LIB)
if (NOT WIN32)
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf z xxhash
glog gflags protobuf z xxhash yaml-cpp
)
if(EXISTS "${PADDLE_LIB}/third_party/install/snappystream/lib")
set(DEPS ${DEPS} snappystream)
@@ -164,7 +169,7 @@ if (NOT WIN32)
else()
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags_static libprotobuf xxhash)
glog gflags_static libprotobuf xxhash libyaml-cppmt)
set(DEPS ${DEPS} libcmt shlwapi)
if (EXISTS "${PADDLE_LIB}/third_party/install/snappy/lib")
set(DEPS ${DEPS} snappy)
@@ -204,6 +209,7 @@ include_directories(${FETCHCONTENT_BASE_DIR}/extern_autolog-src)
AUX_SOURCE_DIRECTORY(./src SRCS)
add_executable(${DEMO_NAME} ${SRCS})
ADD_DEPENDENCIES(${DEMO_NAME} ext-yaml-cpp)
target_link_libraries(${DEMO_NAME} ${DEPS})
......
find_package(Git REQUIRED)
include(ExternalProject)
message("${CMAKE_BUILD_TYPE}")
ExternalProject_Add(
ext-yaml-cpp
URL https://bj.bcebos.com/paddlex/deploy/deps/yaml-cpp.zip
URL_MD5 9542d6de397d1fbd649ed468cb5850e6
CMAKE_ARGS
-DYAML_CPP_BUILD_TESTS=OFF
-DYAML_CPP_BUILD_TOOLS=OFF
-DYAML_CPP_INSTALL=OFF
-DYAML_CPP_BUILD_CONTRIB=OFF
-DMSVC_SHARED_RT=OFF
-DBUILD_SHARED_LIBS=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=${CMAKE_BINARY_DIR}/ext/yaml-cpp/lib
-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=${CMAKE_BINARY_DIR}/ext/yaml-cpp/lib
PREFIX "${CMAKE_BINARY_DIR}/ext/yaml-cpp"
# Disable install step
INSTALL_COMMAND ""
LOG_DOWNLOAD ON
LOG_BUILD 1
)
@@ -28,64 +28,62 @@
#include <fstream>
#include <numeric>
#include "include/cls_config.h"
#include <include/preprocess_op.h>
using namespace paddle_infer;
namespace PaddleClas {
class Classifier {
public:
explicit Classifier(const std::string &model_path,
const std::string &params_path, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const bool &use_tensorrt,
const bool &use_fp16, const int &resize_short_size,
const int &crop_size) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;
this->use_tensorrt_ = use_tensorrt;
this->use_fp16_ = use_fp16;
this->resize_short_size_ = resize_short_size;
this->crop_size_ = crop_size;
LoadModel(model_path, params_path);
}
// Load Paddle inference model
void LoadModel(const std::string &model_path, const std::string &params_path);
// Run predictor
double Run(cv::Mat &img, std::vector<double> *times);
private:
std::shared_ptr<Predictor> predictor_;
bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
bool use_tensorrt_ = false;
bool use_fp16_ = false;
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
bool is_scale_ = true;
int resize_short_size_ = 256;
int crop_size_ = 224;
// pre-process
ResizeImg resize_op_;
Normalize normalize_op_;
Permute permute_op_;
CenterCropImg crop_op_;
};
class Classifier {
public:
explicit Classifier(const ClsConfig &config) {
this->use_gpu_ = config.use_gpu;
this->gpu_id_ = config.gpu_id;
this->gpu_mem_ = config.gpu_mem;
this->cpu_math_library_num_threads_ = config.cpu_threads;
this->use_fp16_ = config.use_fp16;
this->use_mkldnn_ = config.use_mkldnn;
this->use_tensorrt_ = config.use_tensorrt;
this->mean_ = config.mean;
this->std_ = config.std;
this->resize_short_size_ = config.resize_short_size;
this->scale_ = config.scale;
this->crop_size_ = config.crop_size;
this->ir_optim_ = config.ir_optim;
LoadModel(config.cls_model_path, config.cls_params_path);
}
// Load Paddle inference model
void LoadModel(const std::string &model_path, const std::string &params_path);
// Run predictor
double Run(cv::Mat &img, std::vector<double> *times);
private:
std::shared_ptr<Predictor> predictor_;
bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
bool use_tensorrt_ = false;
bool use_fp16_ = false;
bool ir_optim_ = true;
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> std_ = {0.229f, 0.224f, 0.225f};
float scale_ = 0.00392157;
int resize_short_size_ = 256;
int crop_size_ = 224;
// pre-process
ResizeImg resize_op_;
Normalize normalize_op_;
Permute permute_op_;
CenterCropImg crop_op_;
};
} // namespace PaddleClas
@@ -14,6 +14,14 @@
#pragma once
#ifdef WIN32
#define OS_PATH_SEP "\\"
#else
#define OS_PATH_SEP "/"
#endif
#include "include/utility.h"
#include "yaml-cpp/yaml.h"
#include <iomanip>
#include <iostream>
#include <map>
@@ -21,70 +29,85 @@
#include <string>
#include <vector>
#include "include/utility.h"
namespace PaddleClas {
class ClsConfig {
public:
explicit ClsConfig(const std::string &config_file) {
config_map_ = LoadConfig(config_file);
this->use_gpu = bool(stoi(config_map_["use_gpu"]));
this->gpu_id = stoi(config_map_["gpu_id"]);
this->gpu_mem = stoi(config_map_["gpu_mem"]);
this->cpu_threads = stoi(config_map_["cpu_threads"]);
this->use_mkldnn = bool(stoi(config_map_["use_mkldnn"]));
this->use_tensorrt = bool(stoi(config_map_["use_tensorrt"]));
this->use_fp16 = bool(stoi(config_map_["use_fp16"]));
this->cls_model_path.assign(config_map_["cls_model_path"]);
this->cls_params_path.assign(config_map_["cls_params_path"]);
this->resize_short_size = stoi(config_map_["resize_short_size"]);
this->crop_size = stoi(config_map_["crop_size"]);
this->benchmark = bool(stoi(config_map_["benchmark"]));
}
bool use_gpu = false;
int gpu_id = 0;
int gpu_mem = 4000;
int cpu_threads = 1;
bool use_mkldnn = false;
bool use_tensorrt = false;
bool use_fp16 = false;
bool benchmark = false;
std::string cls_model_path;
std::string cls_params_path;
int resize_short_size = 256;
int crop_size = 224;
void PrintConfigInfo();
private:
// Load configuration
std::map<std::string, std::string> LoadConfig(const std::string &config_file);
std::vector<std::string> split(const std::string &str,
const std::string &delim);
std::map<std::string, std::string> config_map_;
};
class ClsConfig {
public:
explicit ClsConfig(const std::string &path) {
ReadYamlConfig(path);
this->infer_imgs =
this->config_file["Global"]["infer_imgs"].as<std::string>();
this->batch_size = this->config_file["Global"]["batch_size"].as<int>();
this->use_gpu = this->config_file["Global"]["use_gpu"].as<bool>();
if (this->config_file["Global"]["gpu_id"].IsDefined())
this->gpu_id = this->config_file["Global"]["gpu_id"].as<int>();
else
this->gpu_id = 0;
this->gpu_mem = this->config_file["Global"]["gpu_mem"].as<int>();
this->cpu_threads =
this->config_file["Global"]["cpu_num_threads"].as<int>();
this->use_mkldnn = this->config_file["Global"]["enable_mkldnn"].as<bool>();
this->use_tensorrt = this->config_file["Global"]["use_tensorrt"].as<bool>();
this->use_fp16 = this->config_file["Global"]["use_fp16"].as<bool>();
this->enable_benchmark =
this->config_file["Global"]["enable_benchmark"].as<bool>();
this->ir_optim = this->config_file["Global"]["ir_optim"].as<bool>();
this->enable_profile =
this->config_file["Global"]["enable_profile"].as<bool>();
this->cls_model_path =
this->config_file["Global"]["inference_model_dir"].as<std::string>() +
OS_PATH_SEP + "inference.pdmodel";
this->cls_params_path =
this->config_file["Global"]["inference_model_dir"].as<std::string>() +
OS_PATH_SEP + "inference.pdiparams";
this->resize_short_size =
this->config_file["PreProcess"]["transform_ops"][0]["ResizeImage"]
["resize_short"]
.as<int>();
this->crop_size =
this->config_file["PreProcess"]["transform_ops"][1]["CropImage"]["size"]
.as<int>();
this->scale = this->config_file["PreProcess"]["transform_ops"][2]
["NormalizeImage"]["scale"]
.as<float>();
this->mean = this->config_file["PreProcess"]["transform_ops"][2]
["NormalizeImage"]["mean"]
.as<std::vector<float>>();
this->std = this->config_file["PreProcess"]["transform_ops"][2]
["NormalizeImage"]["std"]
.as<std::vector<float>>();
if (this->config_file["Global"]["benchmark"].IsDefined())
this->benchmark = this->config_file["Global"]["benchmark"].as<bool>();
else
this->benchmark = false;
}
YAML::Node config_file;
bool use_gpu = false;
int gpu_id = 0;
int gpu_mem = 4000;
int cpu_threads = 1;
bool use_mkldnn = false;
bool use_tensorrt = false;
bool use_fp16 = false;
bool benchmark = false;
int batch_size = 1;
bool enable_benchmark = false;
bool ir_optim = true;
bool enable_profile = false;
std::string cls_model_path;
std::string cls_params_path;
std::string infer_imgs;
int resize_short_size = 256;
int crop_size = 224;
float scale = 0.00392157;
std::vector<float> mean = {0.485, 0.456, 0.406};
std::vector<float> std = {0.229, 0.224, 0.225};
void PrintConfigInfo();
void ReadYamlConfig(const std::string &path);
};
} // namespace PaddleClas
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
@@ -31,26 +31,26 @@ using namespace std;
namespace PaddleClas {
class Normalize {
public:
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale = true);
};
class Normalize {
public:
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float &scale);
};
// RGB -> CHW
class Permute {
public:
virtual void Run(const cv::Mat *im, float *data);
};
class CenterCropImg {
public:
virtual void Run(cv::Mat &im, const int crop_size = 224);
};
class ResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len);
};
} // namespace PaddleClas
\ No newline at end of file
class Permute {
public:
virtual void Run(const cv::Mat *im, float *data);
};
class CenterCropImg {
public:
virtual void Run(cv::Mat &im, const int crop_size = 224);
};
class ResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len);
};
} // namespace PaddleClas
@@ -32,15 +32,15 @@
namespace PaddleClas {
class Utility {
public:
static std::vector<std::string> ReadDict(const std::string &path);
// template <class ForwardIterator>
// inline static size_t argmax(ForwardIterator first, ForwardIterator last)
// {
// return std::distance(first, std::max_element(first, last));
// }
};
class Utility {
public:
static std::vector<std::string> ReadDict(const std::string &path);
// template <class ForwardIterator>
// inline static size_t argmax(ForwardIterator first, ForwardIterator last)
// {
// return std::distance(first, std::max_element(first, last));
// }
};
} // namespace PaddleClas
\ No newline at end of file
@@ -143,10 +143,10 @@ tar -xvf paddle_inference.tgz
```
inference/
|--cls_infer.pdmodel
|--cls_infer.pdiparams
|--inference.pdmodel
|--inference.pdiparams
```
**Note**: Of the files above, `cls_infer.pdmodel` stores the model structure and `cls_infer.pdiparams` stores the model parameters. The two file paths must match the `cls_model_path` and `cls_params_path` parameters in the config file `tools/config.txt`.
**Note**: Of the files above, `inference.pdmodel` stores the model structure and `inference.pdiparams` stores the model parameters. The model directory can be set freely, but the model file names must not be changed.
### 2.2 Build the PaddleClas C++ inference demo
@@ -183,6 +183,7 @@ cmake .. \
-DCUDA_LIB=${CUDA_LIB_DIR} \
make -j
cd ..
```
In the above commands,
@@ -200,31 +201,26 @@ make -j
After the above commands finish and compilation completes, a `build` folder is generated in the current directory, containing an executable named `clas_system`.
### Run the demo
* First, modify the corresponding fields in `tools/config.txt`:
* use_gpu: whether to use the GPU;
* gpu_id: id of the GPU card to use;
* gpu_mem: GPU memory to allocate;
* cpu_math_library_num_threads: number of threads used by the underlying math library;
* use_mkldnn: whether to use MKL-DNN acceleration;
* use_tensorrt: whether to use TensorRT acceleration;
* use_fp16: whether to compute with half-precision floats; only effective when use_tensorrt is true;
* cls_model_path: path of the inference model structure file;
* cls_params_path: path of the inference model parameters file;
* resize_short_size: size to which images are resized during preprocessing;
* crop_size: size to which images are cropped during preprocessing.
### 2.3 Run the demo
#### 2.3.1 Set up the config file
* Then modify `tools/run.sh`:
* `./build/clas_system ./tools/config.txt ./docs/imgs/ILSVRC2012_val_00000666.JPEG`
* In the above command, the arguments are: the compiled executable `clas_system`, the runtime config file `config.txt`, and the image to be predicted.
```shell
cp ../configs/inference_cls.yaml tools/
```
Edit the `inference_cls.yaml` file in the `tools` directory following the `Image Classification Inference` section of [Python inference](../../docs/zh_CN/inference_deployment/python_deploy.md). A detailed description of the `yaml` parameters is also given in [Python inference](../../docs/zh_CN/inference_deployment/python_deploy.md).
Set `Global.infer_imgs`, `Global.inference_model_dir`, and related parameters according to where your files are actually stored; see the sketch below.
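For orientation, a minimal `inference_cls.yaml` might look like the following sketch. The keys simply mirror what `include/cls_config.h` reads; every path and value here is a placeholder to adapt to your environment.

```yaml
Global:
  infer_imgs: "./docs/imgs/ILSVRC2012_val_00000666.JPEG"  # image file or directory
  inference_model_dir: "../inference"  # must contain inference.pdmodel / inference.pdiparams
  batch_size: 1
  use_gpu: false
  gpu_id: 0  # optional; defaults to 0 when omitted
  gpu_mem: 4000
  cpu_num_threads: 10
  enable_mkldnn: true
  use_tensorrt: false
  use_fp16: false
  enable_benchmark: false
  ir_optim: true
  enable_profile: false
  benchmark: false
PreProcess:
  transform_ops:
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 0.00392157
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
```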
* Finally, run the following command to classify an image.
#### 2.3.2 Run
```shell
sh tools/run.sh
./build/clas_system -c inference_cls.yaml
# or
./build/clas_system -config inference_cls.yaml
```
* The result is printed to the screen, as shown in the figure below.
The result is printed to the screen, as shown in the figure below.
<div align="center">
<img src="./docs/imgs/cpp_infer_result.png" width="600">
......
@@ -16,98 +16,97 @@
namespace PaddleClas {
void Classifier::LoadModel(const std::string &model_path,
const std::string &params_path) {
paddle_infer::Config config;
config.SetModel(model_path, params_path);
if (this->use_gpu_) {
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
if (this->use_tensorrt_) {
config.EnableTensorRtEngine(
1 << 20, 1, 3,
this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
: paddle_infer::Config::Precision::kFloat32,
false, false);
void Classifier::LoadModel(const std::string &model_path,
const std::string &params_path) {
paddle_infer::Config config;
config.SetModel(model_path, params_path);
if (this->use_gpu_) {
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
if (this->use_tensorrt_) {
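// EnableTensorRtEngine arguments: workspace_size = 1 << 20, max_batch_size = 1,
// min_subgraph_size = 3, then precision, use_static, use_calib_mode.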
config.EnableTensorRtEngine(
1 << 20, 1, 3,
this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
: paddle_infer::Config::Precision::kFloat32,
false, false);
}
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
// cache 10 different shapes for mkldnn to avoid memory leak
config.SetMkldnnCacheCapacity(10);
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}
config.SwitchUseFeedFetchOps(false);
// true for multiple input
config.SwitchSpecifyInputNames(true);
config.SwitchIrOptim(this->ir_optim_);
config.EnableMemoryOptim();
config.DisableGlogInfo();
this->predictor_ = CreatePredictor(config);
}
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
// cache 10 different shapes for mkldnn to avoid memory leak
config.SetMkldnnCacheCapacity(10);
double Classifier::Run(cv::Mat &img, std::vector<double> *times) {
cv::Mat srcimg;
cv::Mat resize_img;
img.copyTo(srcimg);
auto preprocess_start = std::chrono::system_clock::now();
this->resize_op_.Run(img, resize_img, this->resize_short_size_);
this->crop_op_.Run(resize_img, this->crop_size_);
this->normalize_op_.Run(&resize_img, this->mean_, this->std_, this->scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
auto preprocess_end = std::chrono::system_clock::now();
auto infer_start = std::chrono::system_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
std::vector<float> out_data;
auto output_names = this->predictor_->GetOutputNames();
auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
out_data.resize(out_num);
output_t->CopyToCpu(out_data.data());
auto infer_end = std::chrono::system_clock::now();
auto postprocess_start = std::chrono::system_clock::now();
int maxPosition =
max_element(out_data.begin(), out_data.end()) - out_data.begin();
auto postprocess_end = std::chrono::system_clock::now();
std::chrono::duration<float> preprocess_diff =
preprocess_end - preprocess_start;
times->push_back(double(preprocess_diff.count() * 1000));
std::chrono::duration<float> inference_diff = infer_end - infer_start;
double inference_cost_time = double(inference_diff.count() * 1000);
times->push_back(inference_cost_time);
std::chrono::duration<float> postprocess_diff =
postprocess_end - postprocess_start;
times->push_back(double(postprocess_diff.count() * 1000));
std::cout << "result: " << std::endl;
std::cout << "\tclass id: " << maxPosition << std::endl;
std::cout << std::fixed << std::setprecision(10)
<< "\tscore: " << double(out_data[maxPosition]) << std::endl;
return inference_cost_time;
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}
config.SwitchUseFeedFetchOps(false);
// true for multiple input
config.SwitchSpecifyInputNames(true);
config.SwitchIrOptim(true);
config.EnableMemoryOptim();
config.DisableGlogInfo();
this->predictor_ = CreatePredictor(config);
}
double Classifier::Run(cv::Mat &img, std::vector<double> *times) {
cv::Mat srcimg;
cv::Mat resize_img;
img.copyTo(srcimg);
auto preprocess_start = std::chrono::system_clock::now();
this->resize_op_.Run(img, resize_img, this->resize_short_size_);
this->crop_op_.Run(resize_img, this->crop_size_);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
auto preprocess_end = std::chrono::system_clock::now();
auto infer_start = std::chrono::system_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
std::vector<float> out_data;
auto output_names = this->predictor_->GetOutputNames();
auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
out_data.resize(out_num);
output_t->CopyToCpu(out_data.data());
auto infer_end = std::chrono::system_clock::now();
auto postprocess_start = std::chrono::system_clock::now();
int maxPosition =
max_element(out_data.begin(), out_data.end()) - out_data.begin();
auto postprocess_end = std::chrono::system_clock::now();
std::chrono::duration<float> preprocess_diff =
preprocess_end - preprocess_start;
times->push_back(double(preprocess_diff.count() * 1000));
std::chrono::duration<float> inference_diff = infer_end - infer_start;
double inference_cost_time = double(inference_diff.count() * 1000);
times->push_back(inference_cost_time);
std::chrono::duration<float> postprocess_diff =
postprocess_end - postprocess_start;
times->push_back(double(postprocess_diff.count() * 1000));
std::cout << "result: " << std::endl;
std::cout << "\tclass id: " << maxPosition << std::endl;
std::cout << std::fixed << std::setprecision(10)
<< "\tscore: " << double(out_data[maxPosition]) << std::endl;
return inference_cost_time;
}
} // namespace PaddleClas
@@ -16,49 +16,20 @@
namespace PaddleClas {
std::vector<std::string> ClsConfig::split(const std::string &str,
const std::string &delim) {
std::vector<std::string> res;
if ("" == str)
return res;
char *strs = new char[str.length() + 1];
std::strcpy(strs, str.c_str());
char *d = new char[delim.length() + 1];
std::strcpy(d, delim.c_str());
char *p = std::strtok(strs, d);
while (p) {
std::string s = p;
res.push_back(s);
p = std::strtok(NULL, d);
}
return res;
}
std::map<std::string, std::string>
ClsConfig::LoadConfig(const std::string &config_path) {
auto config = Utility::ReadDict(config_path);
std::map<std::string, std::string> dict;
for (int i = 0; i < config.size(); i++) {
// pass for empty line or comment
if (config[i].size() <= 1 || config[i][0] == '#') {
continue;
void ClsConfig::PrintConfigInfo() {
std::cout << "=======Paddle Class inference config======" << std::endl;
std::cout << this->config_file << std::endl;
std::cout << "=======End of Paddle Class inference config======" << std::endl;
}
std::vector<std::string> res = split(config[i], " ");
dict[res[0]] = res[1];
}
return dict;
}
void ClsConfig::PrintConfigInfo() {
std::cout << "=======Paddle Class inference config======" << std::endl;
for (auto iter = config_map_.begin(); iter != config_map_.end(); iter++) {
std::cout << iter->first << " : " << iter->second << std::endl;
}
std::cout << "=======End of Paddle Class inference config======" << std::endl;
}
void ClsConfig::ReadYamlConfig(const std::string &path) {
} // namespace PaddleClas
\ No newline at end of file
try {
this->config_file = YAML::LoadFile(path);
} catch (YAML::BadFile &e) {
std::cout << "Something wrong in yaml file, please check yaml file"
<< std::endl;
exit(1);
}
}
} // namespace PaddleClas
@@ -27,6 +27,7 @@
#include <numeric>
#include <auto_log/autolog.h>
#include <gflags/gflags.h>
#include <include/cls.h>
#include <include/cls_config.h>
@@ -34,74 +35,81 @@ using namespace std;
using namespace cv;
using namespace PaddleClas;
int main(int argc, char **argv) {
if (argc < 3) {
std::cerr << "[ERROR] usage: " << argv[0]
<< " configure_filepath image_path\n";
exit(1);
}
ClsConfig config(argv[1]);
config.PrintConfigInfo();
std::string path(argv[2]);
DEFINE_string(config, "", "Path of yaml file");
DEFINE_string(c, "", "Path of yaml file");
std::vector<std::string> img_files_list;
if (cv::utils::fs::isDirectory(path)) {
std::vector<cv::String> filenames;
cv::glob(path, filenames);
for (auto f : filenames) {
img_files_list.push_back(f);
int main(int argc, char **argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
std::string yaml_path = "";
if (FLAGS_config == "" && FLAGS_c == "") {
std::cerr << "[ERROR] usage: " << std::endl
<< argv[0] << " -c $yaml_path" << std::endl
<< "or:" << std::endl
<< argv[0] << " -config $yaml_path" << std::endl;
exit(1);
} else if (FLAGS_config != "") {
yaml_path = FLAGS_config;
} else {
yaml_path = FLAGS_c;
}
} else {
img_files_list.push_back(path);
}
std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
Classifier classifier(config.cls_model_path, config.cls_params_path,
config.use_gpu, config.gpu_id, config.gpu_mem,
config.cpu_threads, config.use_mkldnn,
config.use_tensorrt, config.use_fp16,
config.resize_short_size, config.crop_size);
double elapsed_time = 0.0;
std::vector<double> cls_times;
int warmup_iter = img_files_list.size() > 5 ? 5 : 0;
for (int idx = 0; idx < img_files_list.size(); ++idx) {
std::string img_path = img_files_list[idx];
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: " << img_path
<< "\n";
exit(-1);
ClsConfig config(yaml_path);
config.PrintConfigInfo();
std::string path(config.infer_imgs);
std::vector<std::string> img_files_list;
if (cv::utils::fs::isDirectory(path)) {
std::vector<cv::String> filenames;
cv::glob(path, filenames);
for (auto f : filenames) {
img_files_list.push_back(f);
}
} else {
img_files_list.push_back(path);
}
cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
double run_time = classifier.Run(srcimg, &cls_times);
if (idx >= warmup_iter) {
elapsed_time += run_time;
std::cout << "Current image path: " << img_path << std::endl;
std::cout << "Current time cost: " << run_time << " s, "
<< "average time cost in all: "
<< elapsed_time / (idx + 1 - warmup_iter) << " s." << std::endl;
} else {
std::cout << "Current time cost: " << run_time << " s." << std::endl;
std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
Classifier classifier(config);
double elapsed_time = 0.0;
std::vector<double> cls_times;
int warmup_iter = img_files_list.size() > 5 ? 5 : 0;
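// when more than 5 images are given, treat the first 5 runs as warmup and
// exclude them from the reported average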
for (int idx = 0; idx < img_files_list.size(); ++idx) {
std::string img_path = img_files_list[idx];
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: " << img_path
<< "\n";
exit(-1);
}
cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
double run_time = classifier.Run(srcimg, &cls_times);
if (idx >= warmup_iter) {
elapsed_time += run_time;
std::cout << "Current image path: " << img_path << std::endl;
std::cout << "Current time cost: " << run_time << " s, "
<< "average time cost in all: "
<< elapsed_time / (idx + 1 - warmup_iter) << " s." << std::endl;
} else {
std::cout << "Current time cost: " << run_time << " s." << std::endl;
}
}
}
std::string presion = "fp32";
std::string presion = "fp32";
if (config.use_fp16)
presion = "fp16";
if (config.benchmark) {
AutoLogger autolog("Classification", config.use_gpu, config.use_tensorrt,
config.use_mkldnn, config.cpu_threads, 1,
"1, 3, 224, 224", presion, cls_times,
img_files_list.size());
autolog.report();
}
return 0;
if (config.use_fp16)
presion = "fp16";
if (config.benchmark) {
AutoLogger autolog("Classification", config.use_gpu, config.use_tensorrt,
config.use_mkldnn, config.cpu_threads, 1,
"1, 3, 224, 224", presion, cls_times,
img_files_list.size());
autolog.report();
}
return 0;
}
@@ -32,59 +32,57 @@
namespace PaddleClas {
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale) {
double e = 1.0;
if (is_scale) {
e /= 255.0;
}
(*im).convertTo(*im, CV_32FC3, e);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) * scale[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) * scale[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) * scale[2];
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float &scale) {
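// per-channel normalization: y = (x * scale - mean) / std,
// where scale is typically 1/255 (0.00392157)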
if (scale) {
(*im).convertTo(*im, CV_32FC3, scale);
}
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
}
}
}
}
}
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
cv::Rect rect(w_start, h_start, crop_size, crop_size);
img = img(rect);
}
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
cv::Rect rect(w_start, h_start, crop_size, crop_size);
img = img(rect);
}
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_short_size) {
int w = img.cols;
int h = img.rows;
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_short_size) {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
int resize_h = round(float(h) * ratio);
int resize_w = round(float(w) * ratio);
int resize_h = round(float(h) * ratio);
int resize_w = round(float(w) * ratio);
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
} // namespace PaddleClas
\ No newline at end of file
} // namespace PaddleClas
@@ -18,3 +18,4 @@ cmake .. \
-DCUDA_LIB=${CUDA_LIB_DIR} \
make -j
cd ..
# model load config
use_gpu 0
gpu_id 0
gpu_mem 4000
cpu_threads 10
use_mkldnn 1
use_tensorrt 0
use_fp16 0
# cls config
cls_model_path /PaddleClas/inference/cls_infer.pdmodel
cls_params_path /PaddleClas/inference/cls_infer.pdiparams
resize_short_size 256
crop_size 224
# for log env info
benchmark 0
./build/clas_system ./tools/config.txt ./docs/imgs/ILSVRC2012_val_00000666.JPEG