未验证 提交 2073ce4a 编写于 作者: J Jason 提交者: GitHub

Merge pull request #195 from syyxsxx/develop

[openvino]add hrnet and python support
......@@ -14,3 +14,5 @@
- [模型量化](../docs/deploy/paddlelite/slim/quant.md)
- [模型裁剪](../docs/deploy/paddlelite/slim/prune.md)
- [Android平台](../docs/deploy/paddlelite/android.md)
- [OpenVINO部署](../docs/deploy/openvino/introduction.md)
- [树莓派部署](../docs/deploy/raspberry/Raspberry.md)
\ No newline at end of file
......@@ -29,6 +29,10 @@ using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
DEFINE_bool(use_mkl, true, "Infering with MKL");
DEFINE_int32(mkl_thread_num,
omp_get_num_procs(),
"Number of mkl threads");
DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_string(key, "", "key of encryption");
DEFINE_string(image, "", "Path of test image file");
......@@ -56,6 +60,8 @@ int main(int argc, char** argv) {
model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id,
FLAGS_key);
......
......@@ -31,6 +31,10 @@ using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
DEFINE_bool(use_mkl, true, "Infering with MKL");
DEFINE_int32(mkl_thread_num,
omp_get_num_procs(),
"Number of mkl threads");
DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_string(key, "", "key of encryption");
DEFINE_string(image, "", "Path of test image file");
......@@ -61,6 +65,8 @@ int main(int argc, char** argv) {
model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id,
FLAGS_key);
int imgs = 1;
......
......@@ -30,6 +30,10 @@ using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
DEFINE_bool(use_mkl, true, "Infering with MKL");
DEFINE_int32(mkl_thread_num,
omp_get_num_procs(),
"Number of mkl threads");
DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_string(key, "", "key of encryption");
DEFINE_string(image, "", "Path of test image file");
......@@ -58,6 +62,8 @@ int main(int argc, char** argv) {
model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id,
FLAGS_key);
int imgs = 1;
......
......@@ -35,8 +35,12 @@ using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
DEFINE_bool(use_mkl, true, "Infering with MKL");
DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_string(key, "", "key of encryption");
DEFINE_int32(mkl_thread_num,
omp_get_num_procs(),
"Number of mkl threads");
DEFINE_bool(use_camera, false, "Infering with Camera");
DEFINE_int32(camera_id, 0, "Camera id");
DEFINE_string(video_path, "", "Path of input video");
......@@ -62,6 +66,8 @@ int main(int argc, char** argv) {
model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id,
FLAGS_key);
......
......@@ -35,6 +35,7 @@ using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
DEFINE_bool(use_mkl, true, "Infering with MKL");
DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_bool(use_camera, false, "Infering with Camera");
DEFINE_int32(camera_id, 0, "Camera id");
......@@ -42,6 +43,9 @@ DEFINE_string(video_path, "", "Path of input video");
DEFINE_bool(show_result, false, "show the result of each frame with a window");
DEFINE_bool(save_result, true, "save the result of each frame to a video");
DEFINE_string(key, "", "key of encryption");
DEFINE_int32(mkl_thread_num,
omp_get_num_procs(),
"Number of mkl threads");
DEFINE_string(save_dir, "output", "Path to save visualized image");
DEFINE_double(threshold,
0.5,
......@@ -64,6 +68,8 @@ int main(int argc, char** argv) {
model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id,
FLAGS_key);
// Open video
......
......@@ -35,8 +35,12 @@ using namespace std::chrono; // NOLINT
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
DEFINE_bool(use_mkl, true, "Infering with MKL");
DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_string(key, "", "key of encryption");
DEFINE_int32(mkl_thread_num,
omp_get_num_procs(),
"Number of mkl threads");
DEFINE_bool(use_camera, false, "Infering with Camera");
DEFINE_int32(camera_id, 0, "Camera id");
DEFINE_string(video_path, "", "Path of input video");
......@@ -62,6 +66,8 @@ int main(int argc, char** argv) {
model.Init(FLAGS_model_dir,
FLAGS_use_gpu,
FLAGS_use_trt,
FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id,
FLAGS_key);
// Open video
......
......@@ -70,6 +70,8 @@ class Model {
* @param model_dir: the directory which contains model.yml
* @param use_gpu: use gpu or not when infering
* @param use_trt: use Tensor RT or not when infering
* @param use_mkl: use mkl or not when infering
* @param mkl_thread_num: number of threads for mkldnn when infering
* @param gpu_id: the id of gpu when infering with using gpu
* @param key: the key of encryption when using encrypted model
* @param use_ir_optim: use ir optimization when infering
......@@ -77,15 +79,26 @@ class Model {
void Init(const std::string& model_dir,
bool use_gpu = false,
bool use_trt = false,
bool use_mkl = true,
int mkl_thread_num = 4,
int gpu_id = 0,
std::string key = "",
bool use_ir_optim = true) {
create_predictor(model_dir, use_gpu, use_trt, gpu_id, key, use_ir_optim);
create_predictor(
model_dir,
use_gpu,
use_trt,
use_mkl,
mkl_thread_num,
gpu_id,
key,
use_ir_optim);
}
void create_predictor(const std::string& model_dir,
bool use_gpu = false,
bool use_trt = false,
bool use_mkl = true,
int mkl_thread_num = 4,
int gpu_id = 0,
std::string key = "",
bool use_ir_optim = true);
......
......@@ -37,7 +37,7 @@ struct Mask {
};
/*
* @brief
* @brief
* This class represents target box in detection or instance segmentation tasks.
* */
struct Box {
......@@ -47,7 +47,7 @@ struct Box {
// confidence score
float score;
std::vector<float> coordinate;
Mask<float> mask;
Mask<int> mask;
};
/*
......
......@@ -21,6 +21,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <iostream>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
......@@ -216,8 +217,7 @@ class Padding : public Transform {
}
if (item["im_padding_value"].IsDefined()) {
im_value_ = item["im_padding_value"].as<std::vector<float>>();
}
else {
} else {
im_value_ = {0, 0, 0};
}
}
......
......@@ -11,16 +11,25 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <omp.h>
#include <algorithm>
#include <fstream>
#include <cstring>
#include "include/paddlex/paddlex.h"
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
namespace PaddleX {
void Model::create_predictor(const std::string& model_dir,
bool use_gpu,
bool use_trt,
bool use_mkl,
int mkl_thread_num,
int gpu_id,
std::string key,
bool use_ir_optim) {
......@@ -40,7 +49,7 @@ void Model::create_predictor(const std::string& model_dir,
}
#endif
if (yaml_input == "") {
// 读取配置文件
// read yaml file
std::ifstream yaml_fin(yaml_file);
yaml_fin.seekg(0, std::ios::end);
size_t yaml_file_size = yaml_fin.tellg();
......@@ -48,7 +57,7 @@ void Model::create_predictor(const std::string& model_dir,
yaml_fin.seekg(0);
yaml_fin.read(&yaml_input[0], yaml_file_size);
}
// 读取配置文件内容
// load yaml file
if (!load_config(yaml_input)) {
std::cerr << "Parse file 'model.yml' failed!" << std::endl;
exit(-1);
......@@ -57,6 +66,10 @@ void Model::create_predictor(const std::string& model_dir,
if (key == "") {
config.SetModel(model_file, params_file);
}
if (use_mkl && name != "HRNet" && name != "DeepLabv3p") {
config.EnableMKLDNN();
config.SetCpuMathLibraryNumThreads(mkl_thread_num);
}
if (use_gpu) {
config.EnableUseGpu(100, gpu_id);
} else {
......@@ -64,13 +77,13 @@ void Model::create_predictor(const std::string& model_dir,
}
config.SwitchUseFeedFetchOps(false);
config.SwitchSpecifyInputNames(true);
// 开启图优化
// enable graph Optim
#if defined(__arm__) || defined(__aarch64__)
config.SwitchIrOptim(false);
#else
config.SwitchIrOptim(use_ir_optim);
#endif
// 开启内存优化
// enable Memory Optim
config.EnableMemoryOptim();
if (use_trt) {
config.EnableTensorRtEngine(
......@@ -108,9 +121,9 @@ bool Model::load_config(const std::string& yaml_input) {
return false;
}
}
// 构建数据处理流
// build data preprocess stream
transforms_.Init(config["Transforms"], to_rgb);
// 读入label list
// read label list
labels.clear();
for (const auto& item : config["_Attributes"]["labels"]) {
int index = labels.size();
......@@ -152,19 +165,19 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
"to function predict()!" << std::endl;
return false;
}
// 处理输入图像
// im preprocess
if (!preprocess(im, &inputs_)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
// 使用加载的模型进行预测
// predict
auto in_tensor = predictor_->GetInputTensor("image");
int h = inputs_.new_im_size_[0];
int w = inputs_.new_im_size_[1];
in_tensor->Reshape({1, 3, h, w});
in_tensor->copy_from_cpu(inputs_.im_data_.data());
predictor_->ZeroCopyRun();
// 取出模型的输出结果
// get result
auto output_names = predictor_->GetOutputNames();
auto output_tensor = predictor_->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
......@@ -174,7 +187,7 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
}
outputs_.resize(size);
output_tensor->copy_to_cpu(outputs_.data());
// 对模型输出结果进行后处理
// postprocess
auto ptr = std::max_element(std::begin(outputs_), std::end(outputs_));
result->category_id = std::distance(std::begin(outputs_), ptr);
result->score = *ptr;
......@@ -198,12 +211,12 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
return false;
}
inputs_batch_.assign(im_batch.size(), ImageBlob());
// 处理输入图像
// preprocess
if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
// 使用加载的模型进行预测
// predict
int batch_size = im_batch.size();
auto in_tensor = predictor_->GetInputTensor("image");
int h = inputs_batch_[0].new_im_size_[0];
......@@ -218,7 +231,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
in_tensor->copy_from_cpu(inputs_data.data());
// in_tensor->copy_from_cpu(inputs_.im_data_.data());
predictor_->ZeroCopyRun();
// 取出模型的输出结果
// get result
auto output_names = predictor_->GetOutputNames();
auto output_tensor = predictor_->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
......@@ -228,7 +241,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
}
outputs_.resize(size);
output_tensor->copy_to_cpu(outputs_.data());
// 对模型输出结果进行后处理
// postprocess
(*results).clear();
(*results).resize(batch_size);
int single_batch_size = size / batch_size;
......@@ -258,7 +271,7 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
return false;
}
// 处理输入图像
// preprocess
if (!preprocess(im, &inputs_)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
......@@ -288,7 +301,7 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
im_info_tensor->copy_from_cpu(im_info);
im_shape_tensor->copy_from_cpu(im_shape);
}
// 使用加载的模型进行预测
// predict
predictor_->ZeroCopyRun();
std::vector<float> output_box;
......@@ -306,7 +319,7 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
return true;
}
int num_boxes = size / 6;
// 解析预测框box
// box postprocess
for (int i = 0; i < num_boxes; ++i) {
Box box;
box.category_id = static_cast<int>(round(output_box[i * 6]));
......@@ -321,7 +334,7 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
box.coordinate = {xmin, ymin, w, h};
result->boxes.push_back(std::move(box));
}
// 实例分割需解析mask
// mask postprocess
if (name == "MaskRCNN") {
std::vector<float> output_mask;
auto output_mask_tensor = predictor_->GetOutputTensor(output_names[1]);
......@@ -337,12 +350,22 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
result->mask_resolution = output_mask_shape[2];
for (int i = 0; i < result->boxes.size(); ++i) {
Box* box = &result->boxes[i];
auto begin_mask =
output_mask.begin() + (i * classes + box->category_id) * mask_pixels;
auto end_mask = begin_mask + mask_pixels;
box->mask.data.assign(begin_mask, end_mask);
box->mask.shape = {static_cast<int>(box->coordinate[2]),
static_cast<int>(box->coordinate[3])};
auto begin_mask =
output_mask.data() + (i * classes + box->category_id) * mask_pixels;
cv::Mat bin_mask(result->mask_resolution,
result->mask_resolution,
CV_32FC1,
begin_mask);
cv::resize(bin_mask,
bin_mask,
cv::Size(box->mask.shape[0], box->mask.shape[1]));
cv::threshold(bin_mask, bin_mask, 0.5, 1, cv::THRESH_BINARY);
auto mask_int_begin = reinterpret_cast<float*>(bin_mask.data);
auto mask_int_end =
mask_int_begin + box->mask.shape[0] * box->mask.shape[1];
box->mask.data.assign(mask_int_begin, mask_int_end);
}
}
return true;
......@@ -366,12 +389,12 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
inputs_batch_.assign(im_batch.size(), ImageBlob());
int batch_size = im_batch.size();
// 处理输入图像
// preprocess
if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
// 对RCNN类模型做批量padding
// RCNN model padding
if (batch_size > 1) {
if (name == "FasterRCNN" || name == "MaskRCNN") {
int max_h = -1;
......@@ -452,10 +475,10 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
im_info_tensor->copy_from_cpu(im_info.data());
im_shape_tensor->copy_from_cpu(im_shape.data());
}
// 使用加载的模型进行预测
// predict
predictor_->ZeroCopyRun();
// 读取所有box
// get all box
std::vector<float> output_box;
auto output_names = predictor_->GetOutputNames();
auto output_box_tensor = predictor_->GetOutputTensor(output_names[0]);
......@@ -472,7 +495,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
}
auto lod_vector = output_box_tensor->lod();
int num_boxes = size / 6;
// 解析预测框box
// box postprocess
(*results).clear();
(*results).resize(batch_size);
for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
......@@ -492,7 +515,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
}
}
// 实例分割需解析mask
// mask postprocess
if (name == "MaskRCNN") {
std::vector<float> output_mask;
auto output_mask_tensor = predictor_->GetOutputTensor(output_names[1]);
......@@ -509,14 +532,24 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
(*results)[i].mask_resolution = output_mask_shape[2];
for (int j = 0; j < (*results)[i].boxes.size(); ++j) {
  // j-th box of the i-th image in the batch. Indexing with i here would
  // process the wrong box and read out of range whenever an image has
  // fewer than i + 1 boxes.
  Box* box = &(*results)[i].boxes[j];
  int category_id = box->category_id;
  // Target mask size is taken from the box; coordinate is {xmin, ymin, w, h}.
  box->mask.shape = {static_cast<int>(box->coordinate[2]),
                     static_cast<int>(box->coordinate[3])};
  // mask_idx is the running box counter across the whole batch and is
  // advanced once per box at the end of this loop body.
  // NOTE(review): assumes the mask tensor holds one entry per detected box
  // in batch order, as the previous iterator-based code did - confirm.
  auto begin_mask =
      output_mask.data() + (mask_idx * classes + category_id) * mask_pixels;
  // Wrap the raw mask (no copy) as a square float image.
  cv::Mat bin_mask(output_mask_shape[2],
                   output_mask_shape[2],
                   CV_32FC1,
                   begin_mask);
  // Scale the mask to the box size, then binarize at 0.5.
  cv::resize(bin_mask,
             bin_mask,
             cv::Size(box->mask.shape[0], box->mask.shape[1]));
  cv::threshold(bin_mask, bin_mask, 0.5, 1, cv::THRESH_BINARY);
  // Copy the thresholded values into the box's mask storage.
  auto mask_int_begin = reinterpret_cast<float*>(bin_mask.data);
  auto mask_int_end =
      mask_int_begin + box->mask.shape[0] * box->mask.shape[1];
  box->mask.data.assign(mask_int_begin, mask_int_end);
  mask_idx++;
}
}
......@@ -537,7 +570,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
return false;
}
// 处理输入图像
// preprocess
if (!preprocess(im, &inputs_)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
......@@ -549,10 +582,10 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
im_tensor->Reshape({1, 3, h, w});
im_tensor->copy_from_cpu(inputs_.im_data_.data());
// 使用加载的模型进行预测
// predict
predictor_->ZeroCopyRun();
// 获取预测置信度,经过argmax后的labelmap
// get labelmap
auto output_names = predictor_->GetOutputNames();
auto output_label_tensor = predictor_->GetOutputTensor(output_names[0]);
std::vector<int> output_label_shape = output_label_tensor->shape();
......@@ -565,7 +598,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
result->label_map.data.resize(size);
output_label_tensor->copy_to_cpu(result->label_map.data.data());
// 获取预测置信度scoremap
// get scoremap
auto output_score_tensor = predictor_->GetOutputTensor(output_names[1]);
std::vector<int> output_score_shape = output_score_tensor->shape();
size = 1;
......@@ -577,7 +610,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
result->score_map.data.resize(size);
output_score_tensor->copy_to_cpu(result->score_map.data.data());
// 解析输出结果到原图大小
// get origin image result
std::vector<uint8_t> label_map(result->label_map.data.begin(),
result->label_map.data.end());
cv::Mat mask_label(result->label_map.shape[1],
......@@ -647,7 +680,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
return false;
}
// 处理输入图像
// preprocess
inputs_batch_.assign(im_batch.size(), ImageBlob());
if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
std::cerr << "Preprocess failed!" << std::endl;
......@@ -670,10 +703,10 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
im_tensor->copy_from_cpu(inputs_data.data());
// im_tensor->copy_from_cpu(inputs_.im_data_.data());
// 使用加载的模型进行预测
// predict
predictor_->ZeroCopyRun();
// 获取预测置信度,经过argmax后的labelmap
// get labelmap
auto output_names = predictor_->GetOutputNames();
auto output_label_tensor = predictor_->GetOutputTensor(output_names[0]);
std::vector<int> output_label_shape = output_label_tensor->shape();
......@@ -698,7 +731,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
(*results)[i].label_map.data.data());
}
// 获取预测置信度scoremap
// get scoremap
auto output_score_tensor = predictor_->GetOutputTensor(output_names[1]);
std::vector<int> output_score_shape = output_score_tensor->shape();
size = 1;
......@@ -722,7 +755,7 @@ bool Model::predict(const std::vector<cv::Mat>& im_batch,
(*results)[i].score_map.data.data());
}
// 解析输出结果到原图大小
// get origin image result
for (int i = 0; i < batch_size; ++i) {
std::vector<uint8_t> label_map((*results)[i].label_map.data.begin(),
(*results)[i].label_map.data.end());
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/paddlex/transforms.h"
#include <math.h>
#include <iostream>
#include <string>
#include <vector>
#include <math.h>
#include "include/paddlex/transforms.h"
namespace PaddleX {
......@@ -195,7 +197,7 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
}
bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
// 按照transforms中预处理算子顺序处理图像
// do all preprocess ops by order
if (to_rgb_) {
cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
}
......@@ -211,8 +213,8 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
}
}
// 将图像由NHWC转为NCHW格式
// 同时转为连续的内存块存储到ImageBlob
// data format NHWC to NCHW
// img data save to ImageBlob
int h = im->rows;
int w = im->cols;
int c = im->channels();
......
......@@ -47,7 +47,7 @@ cv::Mat Visualize(const cv::Mat& img,
boxes[i].coordinate[2],
boxes[i].coordinate[3]);
// 生成预测框和标题
// draw box and title
std::string text = boxes[i].category;
int c1 = colormap[3 * boxes[i].category_id + 0];
int c2 = colormap[3 * boxes[i].category_id + 1];
......@@ -63,13 +63,13 @@ cv::Mat Visualize(const cv::Mat& img,
origin.x = roi.x;
origin.y = roi.y;
// 生成预测框标题的背景
// background
cv::Rect text_back = cv::Rect(boxes[i].coordinate[0],
boxes[i].coordinate[1] - text_size.height,
text_size.width,
text_size.height);
// 绘图和文字
// draw
cv::rectangle(vis_img, roi, roi_color, 2);
cv::rectangle(vis_img, text_back, roi_color, -1);
cv::putText(vis_img,
......@@ -80,18 +80,16 @@ cv::Mat Visualize(const cv::Mat& img,
cv::Scalar(255, 255, 255),
thickness);
// 生成实例分割mask
// mask
if (boxes[i].mask.data.size() == 0) {
continue;
}
cv::Mat bin_mask(result.mask_resolution,
result.mask_resolution,
std::vector<float> mask_data;
mask_data.assign(boxes[i].mask.data.begin(), boxes[i].mask.data.end());
cv::Mat bin_mask(boxes[i].mask.shape[1],
boxes[i].mask.shape[0],
CV_32FC1,
boxes[i].mask.data.data());
cv::resize(bin_mask,
bin_mask,
cv::Size(boxes[i].mask.shape[0], boxes[i].mask.shape[1]));
cv::threshold(bin_mask, bin_mask, 0.5, 1, cv::THRESH_BINARY);
cv::Mat full_mask = cv::Mat::zeros(vis_img.size(), CV_8UC1);
bin_mask.copyTo(full_mask(roi));
cv::Mat mask_ch[3];
......
......@@ -8,7 +8,9 @@ SET(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH})
# User-supplied locations of the third-party dependencies
# (pass them on the command line with -D<NAME>=<path>).
SET(OPENVINO_DIR "" CACHE PATH "Location of libraries")
SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
SET(GFLAGS_DIR "" CACHE PATH "Location of libraries")
SET(GLOG_DIR "" CACHE PATH "Location of libraries")
SET(NGRAPH_LIB "" CACHE PATH "Location of libraries")
# ARCH is a plain name (x86 or armv7), not a directory, so it must be a
# STRING cache entry. With type PATH, an untyped -DARCH=armv7 is treated as
# a relative path and converted to an absolute one, which then never
# compares equal to "armv7".
SET(ARCH "" CACHE STRING "Target architecture: x86 or armv7")
include(cmake/yaml-cpp.cmake)
......@@ -27,6 +29,12 @@ macro(safe_set_static_flag)
endforeach(flag_var)
endmacro()
# ARCH is mandatory on non-Windows builds.
# The expansion is quoted so that the comparison stays well-formed when the
# variable is empty (an unquoted empty expansion removes the left operand).
if(NOT WIN32)
  if (NOT DEFINED ARCH OR "${ARCH}" STREQUAL "")
    message(FATAL_ERROR "please set ARCH with -DARCH=x86 OR armv7")
  endif()
endif()
# OPENVINO_DIR must point at the OpenVINO inference engine directory.
# The expansion is quoted so the test is still valid when the value is empty.
if (NOT DEFINED OPENVINO_DIR OR "${OPENVINO_DIR}" STREQUAL "")
  message(FATAL_ERROR "please set OPENVINO_DIR with -DOPENVINO_DIR=/path/inference_engine")
endif()
......@@ -39,19 +47,32 @@ if (NOT DEFINED GFLAGS_DIR OR ${GFLAGS_DIR} STREQUAL "")
message(FATAL_ERROR "please set GFLAGS_DIR with -DGFLAGS_DIR=/path/gflags")
endif()
# GLOG_DIR must point at the glog installation (its lib/ and include/
# sub-directories are added to the search paths).
# The hint names the real option, -DGLOG_DIR, and the expansion is quoted so
# the test is still valid when the value is empty.
if (NOT DEFINED GLOG_DIR OR "${GLOG_DIR}" STREQUAL "")
  message(FATAL_ERROR "please set GLOG_DIR with -DGLOG_DIR=/path/glog")
endif()
# NGRAPH_LIB must point at the ngraph libraries.
# The hint names the variable that is actually checked (NGRAPH_LIB, not
# NGRAPH_DIR), and the expansion is quoted so the test is still valid when
# the value is empty.
if (NOT DEFINED NGRAPH_LIB OR "${NGRAPH_LIB}" STREQUAL "")
  message(FATAL_ERROR "please set NGRAPH_LIB with -DNGRAPH_LIB=/path/ngraph")
endif()
include_directories("${OPENVINO_DIR}")
link_directories("${OPENVINO_DIR}/lib")
include_directories("${OPENVINO_DIR}/include")
link_directories("${OPENVINO_DIR}/external/tbb/lib")
include_directories("${OPENVINO_DIR}/external/tbb/include/tbb")
link_directories("${OPENVINO_DIR}/lib")
link_directories("${OPENVINO_DIR}/external/tbb/lib")
if(WIN32)
link_directories("${OPENVINO_DIR}/lib/intel64/Release")
link_directories("${OPENVINO_DIR}/bin/intel64/Release")
endif()
link_directories("${GFLAGS_DIR}/lib")
include_directories("${GFLAGS_DIR}/include")
link_directories("${GLOG_DIR}/lib")
include_directories("${GLOG_DIR}/include")
link_directories("${NGRAPH_LIB}")
link_directories("${NGRAPH_LIB}/lib")
......@@ -79,14 +100,29 @@ else()
set(CMAKE_STATIC_LIBRARY_PREFIX "")
endif()
if(WITH_STATIC_LIB)
set(DEPS ${OPENVINO_DIR}/lib/intel64/libinference_engine${CMAKE_STATIC_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${OPENVINO_DIR}/lib/intel64/libinference_engine_legacy${CMAKE_STATIC_LIBRARY_SUFFIX})
if(WIN32)
set(DEPS ${OPENVINO_DIR}/lib/intel64/Release/inference_engine${CMAKE_STATIC_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${OPENVINO_DIR}/lib/intel64/Release/inference_engine_legacy${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(DEPS ${OPENVINO_DIR}/lib/intel64/libinference_engine${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${OPENVINO_DIR}/lib/intel64/libinference_engine_legacy${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
if (ARCH STREQUAL "armv7")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv7-a")
if(WITH_STATIC_LIB)
set(DEPS ${OPENVINO_DIR}/lib/armv7l/libinference_engine${CMAKE_STATIC_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${OPENVINO_DIR}/lib/armv7l/libinference_engine_legacy${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(DEPS ${OPENVINO_DIR}/lib/armv7l/libinference_engine${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${OPENVINO_DIR}/lib/armv7l/libinference_engine_legacy${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
else()
if(WITH_STATIC_LIB)
set(DEPS ${OPENVINO_DIR}/lib/intel64/libinference_engine${CMAKE_STATIC_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${OPENVINO_DIR}/lib/intel64/libinference_engine_legacy${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(DEPS ${OPENVINO_DIR}/lib/intel64/libinference_engine${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${OPENVINO_DIR}/lib/intel64/libinference_engine_legacy${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
endif()
endif(WIN32)
if (NOT WIN32)
set(DEPS ${DEPS}
......@@ -94,7 +130,7 @@ if (NOT WIN32)
)
else()
set(DEPS ${DEPS}
glog gflags_static libprotobuf zlibstatic xxhash libyaml-cppmt)
glog gflags_static libyaml-cppmt)
set(DEPS ${DEPS} libcmt shlwapi)
endif(NOT WIN32)
......@@ -105,7 +141,14 @@ if (NOT WIN32)
endif()
set(DEPS ${DEPS} ${OpenCV_LIBS})
add_executable(classifier src/classifier.cpp src/transforms.cpp src/paddlex.cpp)
add_executable(classifier demo/classifier.cpp src/transforms.cpp src/paddlex.cpp)
ADD_DEPENDENCIES(classifier ext-yaml-cpp)
target_link_libraries(classifier ${DEPS})
add_executable(segmenter demo/segmenter.cpp src/transforms.cpp src/paddlex.cpp src/visualize.cpp)
ADD_DEPENDENCIES(segmenter ext-yaml-cpp)
target_link_libraries(segmenter ${DEPS})
add_executable(detector demo/detector.cpp src/transforms.cpp src/paddlex.cpp src/visualize.cpp)
ADD_DEPENDENCIES(detector ext-yaml-cpp)
target_link_libraries(detector ${DEPS})
{
"configurations": [
"configurations": [
{
"name": "x64-Release",
"generator": "Ninja",
"configurationType": "RelWithDebInfo",
"inheritEnvironments": [ "msvc_x64_x64" ],
"buildRoot": "${projectDir}\\out\\build\\${name}",
"installRoot": "${projectDir}\\out\\install\\${name}",
"cmakeCommandArgs": "",
"buildCommandArgs": "-v",
"ctestCommandArgs": "",
"variables": [
{
"name": "x64-Release",
"generator": "Ninja",
"configurationType": "RelWithDebInfo",
"inheritEnvironments": [ "msvc_x64_x64" ],
"buildRoot": "${projectDir}\\out\\build\\${name}",
"installRoot": "${projectDir}\\out\\install\\${name}",
"cmakeCommandArgs": "",
"buildCommandArgs": "-v",
"ctestCommandArgs": "",
"variables": [
{
"name": "OPENCV_DIR",
"value": "C:/projects/opencv",
"type": "PATH"
},
{
"name": "OPENVINO_LIB",
"value": "C:/projetcs/inference_engine",
"type": "PATH"
}
]
"name": "OPENCV_DIR",
"value": "/path/to/opencv",
"type": "PATH"
},
{
"name": "OPENVINO_DIR",
"value": "C:/Program Files (x86)/IntelSWTools/openvino/deployment_tools/inference_engine",
"type": "PATH"
},
{
"name": "NGRAPH_LIB",
"value": "C:/Program Files (x86)/IntelSWTools/openvino/deployment_tools/ngraph/lib",
"type": "PATH"
},
{
"name": "GFLAGS_DIR",
"value": "/path/to/gflags",
"type": "PATH"
},
{
"name": "WITH_STATIC_LIB",
"value": "True",
"type": "BOOL"
},
{
"name": "GLOG_DIR",
"value": "/path/to/glog",
"type": "PATH"
}
]
}
]
}
]
}
\ No newline at end of file
find_package(Git REQUIRED)
include(ExternalProject)
......
......@@ -22,7 +22,7 @@
#include "include/paddlex/paddlex.h"
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(cfg_dir, "", "Path of inference model");
DEFINE_string(cfg_file, "", "Path of PaddelX model yml file");
DEFINE_string(device, "CPU", "Device name");
DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
......@@ -35,8 +35,8 @@ int main(int argc, char** argv) {
std::cerr << "--model_dir need to be defined" << std::endl;
return -1;
}
if (FLAGS_cfg_dir == "") {
std::cerr << "--cfg_dir need to be defined" << std::endl;
if (FLAGS_cfg_file == "") {
std::cerr << "--cfg_file need to be defined" << std::endl;
return -1;
}
if (FLAGS_image == "" & FLAGS_image_list == "") {
......@@ -44,11 +44,11 @@ int main(int argc, char** argv) {
return -1;
}
// 加载模型
// load model
PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_cfg_dir, FLAGS_device);
model.Init(FLAGS_model_dir, FLAGS_cfg_file, FLAGS_device);
// 进行预测
// predict
if (FLAGS_image_list != "") {
std::ifstream inf(FLAGS_image_list);
if (!inf) {
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <omp.h>
#include <algorithm>
#include <chrono> // NOLINT
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <utility>
#include "include/paddlex/paddlex.h"
#include "include/paddlex/visualize.h"
using namespace std::chrono; // NOLINT
// Command line flags of the detection demo (gflags).
// model_dir and cfg_file are mandatory: main() exits with -1 when either is
// empty.
DEFINE_string(model_dir, "", "Path of openvino model xml file");
DEFINE_string(cfg_file, "", "Path of PaddleX model yaml file");
// At least one of image / image_list must be given; image_list takes
// precedence in main() when both are set.
DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
// Passed to Model::Init as the device name.
DEFINE_string(device, "CPU", "Device name");
// Visualized results are written only when save_dir is non-empty.
DEFINE_string(save_dir, "", "Path to save visualized image");
// NOTE(review): batch_size is defined but not read by main() in this file.
DEFINE_int32(batch_size, 1, "Batch size of infering");
// Forwarded to PaddleX::Visualize.
DEFINE_double(threshold,
0.5,
"The minimum scores of target boxes which are shown");
/*
 * @brief
 * Entry point of the detection demo. Loads the model, runs prediction on a
 * single image (--image) or on every path listed in --image_list, and
 * writes a visualized result per image when --save_dir is set.
 *
 * @param argc: number of command line arguments
 * @param argv: command line arguments, parsed by gflags
 * @return 0 on success; -1 when a required flag is missing, the image list
 *         cannot be opened, or the single input image cannot be read
 * */
int main(int argc, char** argv) {
  google::ParseCommandLineFlags(&argc, &argv, true);

  if (FLAGS_model_dir == "") {
    std::cerr << "--model_dir need to be defined" << std::endl;
    return -1;
  }
  if (FLAGS_cfg_file == "") {
    std::cerr << "--cfg_file need to be defined" << std::endl;
    return -1;
  }
  // Logical (not bitwise) AND: at least one input source is required.
  if (FLAGS_image == "" && FLAGS_image_list == "") {
    std::cerr << "--image or --image_list need to be defined" << std::endl;
    return -1;
  }

  // load model
  PaddleX::Model model;
  model.Init(FLAGS_model_dir, FLAGS_cfg_file, FLAGS_device);

  auto colormap = PaddleX::GenerateColorMap(model.labels.size());

  // predict
  if (FLAGS_image_list != "") {
    std::ifstream inf(FLAGS_image_list);
    if (!inf) {
      std::cerr << "Fail to open file " << FLAGS_image_list << std::endl;
      return -1;
    }
    std::string image_path;
    while (std::getline(inf, image_path)) {
      PaddleX::DetResult result;
      cv::Mat im = cv::imread(image_path, 1);
      // cv::imread returns an empty matrix when the file cannot be read;
      // skip such entries instead of feeding an empty image to the model.
      if (im.empty()) {
        std::cerr << "Fail to read image " << image_path << std::endl;
        continue;
      }
      model.predict(im, &result);
      if (FLAGS_save_dir != "") {
        cv::Mat vis_img = PaddleX::Visualize(
            im, result, model.labels, colormap, FLAGS_threshold);
        // Build the output name from the image being processed. Using
        // FLAGS_image here would give every list entry the same save path.
        std::string save_path =
            PaddleX::generate_save_path(FLAGS_save_dir, image_path);
        cv::imwrite(save_path, vis_img);
        std::cout << "Visualized output saved as " << save_path << std::endl;
      }
    }
  } else {
    PaddleX::DetResult result;
    cv::Mat im = cv::imread(FLAGS_image, 1);
    if (im.empty()) {
      std::cerr << "Fail to read image " << FLAGS_image << std::endl;
      return -1;
    }
    model.predict(im, &result);
    // Print every detected box.
    for (const auto& box : result.boxes) {
      std::cout << "image file: " << FLAGS_image << std::endl;
      std::cout << ", predict label: " << box.category
                << ", label_id:" << box.category_id
                << ", score: " << box.score
                << ", box(xmin, ymin, w, h):(" << box.coordinate[0]
                << ", " << box.coordinate[1] << ", "
                << box.coordinate[2] << ", "
                << box.coordinate[3] << ")" << std::endl;
    }
    if (FLAGS_save_dir != "") {
      // visualize
      cv::Mat vis_img = PaddleX::Visualize(
          im, result, model.labels, colormap, FLAGS_threshold);
      std::string save_path =
          PaddleX::generate_save_path(FLAGS_save_dir, FLAGS_image);
      cv::imwrite(save_path, vis_img);
      result.clear();
      std::cout << "Visualized output saved as " << save_path << std::endl;
    }
  }
  return 0;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <utility>
#include "include/paddlex/paddlex.h"
#include "include/paddlex/visualize.h"
// Command-line flags of the segmentation demo; they are parsed in main() by
// google::ParseCommandLineFlags and read through the FLAGS_<name> variables.
// model_dir, cfg_file and device are forwarded to PaddleX::Model::Init.
DEFINE_string(model_dir, "", "Path of openvino model xml file");
DEFINE_string(cfg_file, "", "Path of PaddleX model yaml file");
// Either image (single file) or image_list (text file, one path per line)
// must be given.
DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_string(device, "CPU", "Device name");
// When non-empty, the visualized prediction is written under this directory.
DEFINE_string(save_dir, "", "Path to save visualized image");
// NOTE(review): batch_size is not referenced by the main() that follows.
DEFINE_int32(batch_size, 1, "Batch size of infering");
// Entry point of the OpenVINO segmentation demo.
//
// Parses the command-line flags, loads the model described by --model_dir /
// --cfg_file on --device, then runs segmentation either on every path listed
// in --image_list (one path per line) or on the single file given by --image.
// When --save_dir is set, the visualized result of each image is written
// there.
//
// Returns 0 on success and -1 when a required flag is missing, the image list
// cannot be opened, or the single input image cannot be read or predicted.
int main(int argc, char** argv) {
  google::ParseCommandLineFlags(&argc, &argv, true);
  if (FLAGS_model_dir.empty()) {
    std::cerr << "--model_dir need to be defined" << std::endl;
    return -1;
  }
  if (FLAGS_cfg_file.empty()) {
    std::cerr << "--cfg_file need to be defined" << std::endl;
    return -1;
  }
  // Logical (short-circuit) and; the original used the bitwise operator.
  if (FLAGS_image.empty() && FLAGS_image_list.empty()) {
    std::cerr << "--image or --image_list need to be defined" << std::endl;
    return -1;
  }

  // load model
  PaddleX::Model model;
  model.Init(FLAGS_model_dir, FLAGS_cfg_file, FLAGS_device);
  auto colormap = PaddleX::GenerateColorMap(model.labels.size());

  if (!FLAGS_image_list.empty()) {
    std::ifstream inf(FLAGS_image_list);
    if (!inf) {
      std::cerr << "Fail to open file " << FLAGS_image_list << std::endl;
      return -1;
    }
    std::string image_path;
    while (std::getline(inf, image_path)) {
      if (image_path.empty()) {
        continue;  // tolerate blank lines in the list file
      }
      cv::Mat im = cv::imread(image_path, 1);
      if (im.empty()) {
        // imread yields an empty matrix when the file cannot be decoded;
        // report it and go on with the remaining images.
        std::cerr << "Fail to read image " << image_path << std::endl;
        continue;
      }
      PaddleX::SegResult result;
      if (!model.predict(im, &result)) {
        std::cerr << "Fail to predict image " << image_path << std::endl;
        continue;
      }
      if (!FLAGS_save_dir.empty()) {
        cv::Mat vis_img =
            PaddleX::Visualize(im, result, model.labels, colormap);
        std::string save_path =
            PaddleX::generate_save_path(FLAGS_save_dir, image_path);
        cv::imwrite(save_path, vis_img);
        std::cout << "Visualized output saved as " << save_path << std::endl;
      }
    }
  } else {
    cv::Mat im = cv::imread(FLAGS_image, 1);
    if (im.empty()) {
      std::cerr << "Fail to read image " << FLAGS_image << std::endl;
      return -1;
    }
    PaddleX::SegResult result;
    if (!model.predict(im, &result)) {
      std::cerr << "Fail to predict image " << FLAGS_image << std::endl;
      return -1;
    }
    if (!FLAGS_save_dir.empty()) {
      cv::Mat vis_img = PaddleX::Visualize(im, result, model.labels, colormap);
      std::string save_path =
          PaddleX::generate_save_path(FLAGS_save_dir, FLAGS_image);
      cv::imwrite(save_path, vis_img);
      // Same message as the list branch (a stray backtick was removed).
      std::cout << "Visualized output saved as " << save_path << std::endl;
    }
    result.clear();
  }
  return 0;
}
......@@ -54,4 +54,4 @@ class ConfigPaser {
YAML::Node Transforms_;
};
} // namespace PaddleDetection
} // namespace PaddleX
......@@ -17,6 +17,8 @@
#include <functional>
#include <iostream>
#include <numeric>
#include <map>
#include <string>
#include "yaml-cpp/yaml.h"
......@@ -30,35 +32,40 @@
#include "include/paddlex/config_parser.h"
#include "include/paddlex/results.h"
#include "include/paddlex/transforms.h"
using namespace InferenceEngine;
namespace PaddleX {
// Bundles an Inference Engine network (network_ / executable_network_) with
// the PaddleX preprocessing pipeline (transforms_) and exposes predict()
// overloads for classification, detection and segmentation results.
//
// NOTE(review): this span looks like a rendered diff in which removed and
// added lines are interleaved (cfg_dir/cfg_file parameters, two `labels`
// members, two preprocess() declarations, two sets of inputs_/output_/
// network_ members). Confirm which variant is current before compiling.
class Model {
public:
// Initializes the model by delegating to create_predictor().
void Init(const std::string& model_dir,
const std::string& cfg_dir,
const std::string& cfg_file,
std::string device) {
create_predictor(model_dir, cfg_dir, device);
create_predictor(model_dir, cfg_file, device);
}
// Defined elsewhere; receives the model path, the config path and the device name.
void create_predictor(const std::string& model_dir,
const std::string& cfg_dir,
const std::string& cfg_file,
std::string device);
// Defined elsewhere; presumably parses the PaddleX yaml config - TODO confirm.
bool load_config(const std::string& model_dir);
// Defined elsewhere; preprocessing entry points taking the input image matrix.
bool preprocess(cv::Mat* input_im);
bool preprocess(cv::Mat* input_im, ImageBlob* inputs);
// One predict() overload per result type; each writes into *result.
bool predict(const cv::Mat& im, ClsResult* result);
bool predict(const cv::Mat& im, DetResult* result);
bool predict(const cv::Mat& im, SegResult* result);
// Model type and model name; presumably filled from the config - TODO confirm.
std::string type;
std::string name;
// Class labels (the map variant is keyed by an int id).
std::vector<std::string> labels;
std::map<int, std::string> labels;
// Preprocessing pipeline.
Transforms transforms_;
// Input/output buffers and the Inference Engine network handles.
Blob::Ptr inputs_;
Blob::Ptr output_;
CNNNetwork network_;
ExecutableNetwork executable_network_;
ImageBlob inputs_;
InferenceEngine::Blob::Ptr output_;
InferenceEngine::CNNNetwork network_;
InferenceEngine::ExecutableNetwork executable_network_;
};
} // namespce of PaddleX
} // namespace PaddleX
......@@ -61,11 +61,11 @@ class DetResult : public BaseResult {
// Result of a segmentation prediction: a label mask and a score mask.
// NOTE(review): the two label_map declarations look like the removed
// (Mask<int64_t>) and added (Mask<int>) lines of a rendered diff; confirm
// which element type is current.
class SegResult : public BaseResult {
public:
Mask<int64_t> label_map;
Mask<int> label_map;
Mask<float> score_map;
// Clears both masks.
void clear() {
label_map.clear();
score_map.clear();
}
};
} // namespce of PaddleX
} // namespace PaddleX
......@@ -16,26 +16,54 @@
#include <yaml-cpp/yaml.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <memory>
#include <string>
#include <vector>
#include <iostream>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <inference_engine.hpp>
using namespace InferenceEngine;
namespace PaddleX {
/*
 * @brief
 * Holder for the data produced while preprocessing one image: size
 * bookkeeping, the resize scale and the buffer with the processed pixels.
 * */
class ImageBlob {
 public:
  // Blob with the height and width of the original image
  InferenceEngine::Blob::Ptr ori_im_size_;

  // Height and width after the latest processing step (two entries)
  std::vector<int> new_im_size_ = std::vector<int>(2, 0);

  // Heights and widths recorded before each resize
  std::vector<std::vector<int>> im_size_before_resize_;

  // Order of the reshape operations
  std::vector<std::string> reshape_order_;

  // Scale factor of the resize
  float scale = 1.0f;

  // Buffer holding the image data after preprocessing
  InferenceEngine::Blob::Ptr blob;

  // Drops the recorded resize history; the other members are left untouched.
  void clear() {
    reshape_order_.clear();
    im_size_before_resize_.clear();
  }
};
// Abstract base class of the preprocessing operations.
// Init() configures an operation from its YAML node; Run() receives the image
// matrix to process and returns a bool.
// NOTE(review): the two Run() declarations look like the removed and added
// lines of a rendered diff (the second one additionally takes an ImageBlob);
// confirm which signature is current.
class Transform {
public:
virtual void Init(const YAML::Node& item) = 0;
virtual bool Run(cv::Mat* im) = 0;
virtual bool Run(cv::Mat* im, ImageBlob* data) = 0;
};
class Normalize : public Transform {
......@@ -45,7 +73,7 @@ class Normalize : public Transform {
std_ = item["std"].as<std::vector<float>>();
}
virtual bool Run(cv::Mat* im);
virtual bool Run(cv::Mat* im, ImageBlob* data);
private:
std::vector<float> mean_;
......@@ -61,8 +89,8 @@ class ResizeByShort : public Transform {
} else {
max_size_ = -1;
}
};
virtual bool Run(cv::Mat* im);
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
private:
float GenerateScale(const cv::Mat& im);
......@@ -70,6 +98,55 @@ class ResizeByShort : public Transform {
int max_size_;
};
/*
 * @brief
 * Resize-by-long-side operation. The long side of the image matrix is resized
 * to the configured length and the short side is resized in the same
 * proportion.
 * */
class ResizeByLong : public Transform {
 public:
  // Reads the mandatory "long_size" entry of the YAML node.
  virtual void Init(const YAML::Node& item) {
    const YAML::Node long_size_node = item["long_size"];
    long_size_ = long_size_node.as<int>();
  }

  virtual bool Run(cv::Mat* im, ImageBlob* data);

 private:
  // Target length of the long side, taken from the config.
  int long_size_;
};
/*
 * @brief
 * This class executes the resize operation on an image matrix. It resizes
 * width and height to the specified lengths.
 * */
class Resize : public Transform {
 public:
  /*
   * @brief
   * Reads the configuration of the operation.
   *
   * @param item: YAML node; "interp" is optional, "target_size" is either a
   *              scalar (used for both width and height) or a sequence
   *              [width, height]
   *
   * The process is terminated with exit(-1) when no positive width and
   * height can be obtained from "target_size".
   * */
  virtual void Init(const YAML::Node& item) {
    if (item["interp"].IsDefined()) {
      interp_ = item["interp"].as<std::string>();
    }
    if (item["target_size"].IsScalar()) {
      height_ = item["target_size"].as<int>();
      width_ = height_;
    } else if (item["target_size"].IsSequence()) {
      const std::vector<int> target_size =
          item["target_size"].as<std::vector<int>>();
      // Guard the two element accesses below against short sequences.
      if (target_size.size() < 2) {
        std::cerr << "[Resize] target_size should contain width and height"
                  << std::endl;
        exit(-1);
      }
      width_ = target_size[0];
      height_ = target_size[1];
    }
    // height_ and width_ start at 0, so a target_size of any other node type
    // is rejected here instead of reading indeterminate values.
    if (height_ <= 0 || width_ <= 0) {
      std::cerr << "[Resize] target_size should greater than 0" << std::endl;
      exit(-1);
    }
  }
  virtual bool Run(cv::Mat* im, ImageBlob* data);

 private:
  int height_ = 0;
  int width_ = 0;
  // Interpolation name from the config; stays empty when "interp" is absent.
  std::string interp_;
};
class CenterCrop : public Transform {
public:
......@@ -83,22 +160,65 @@ class CenterCrop : public Transform {
height_ = crop_size[1];
}
}
virtual bool Run(cv::Mat* im);
virtual bool Run(cv::Mat* im, ImageBlob* data);
private:
int height_;
int width_;
};
/*
 * @brief
 * This class executes the padding operation on an image matrix. It makes a
 * border on the edge of the image matrix.
 * */
class Padding : public Transform {
 public:
  /*
   * @brief
   * Reads the configuration of the operation.
   *
   * @param item: YAML node with the optional entries
   *              "coarsest_stride" (int, must be >= 1),
   *              "target_size" (scalar, or sequence [width, height]) and
   *              "im_padding_value" (float sequence, default {0, 0, 0})
   *
   * The process is terminated with exit(-1) on an invalid stride or a
   * target_size sequence with fewer than two entries.
   * */
  virtual void Init(const YAML::Node& item) {
    if (item["coarsest_stride"].IsDefined()) {
      coarsest_stride_ = item["coarsest_stride"].as<int>();
      if (coarsest_stride_ < 1) {
        // The message now spells the config key correctly.
        std::cerr << "[Padding] coarsest_stride should greater than 0"
                  << std::endl;
        exit(-1);
      }
    }
    if (item["target_size"].IsDefined()) {
      if (item["target_size"].IsScalar()) {
        width_ = item["target_size"].as<int>();
        height_ = width_;
      } else if (item["target_size"].IsSequence()) {
        // Convert the sequence once and validate its length before indexing.
        const std::vector<int> target_size =
            item["target_size"].as<std::vector<int>>();
        if (target_size.size() < 2) {
          std::cerr << "[Padding] target_size should contain width and height"
                    << std::endl;
          exit(-1);
        }
        width_ = target_size[0];
        height_ = target_size[1];
      }
    }
    if (item["im_padding_value"].IsDefined()) {
      im_value_ = item["im_padding_value"].as<std::vector<float>>();
    } else {
      im_value_ = {0, 0, 0};
    }
  }
  virtual bool Run(cv::Mat* im, ImageBlob* data);

 private:
  // -1 means that no stride was configured.
  int coarsest_stride_ = -1;
  // 0 means that no target size was configured.
  int width_ = 0;
  int height_ = 0;
  // Value used to fill the border.
  std::vector<float> im_value_;
};
// Ordered collection of Transform operations built from the YAML config and
// run over an image.
// NOTE(review): the paired Init() and Run() declarations look like the removed
// and added lines of a rendered diff (the newer ones take the model type and
// an ImageBlob); confirm which signatures are current.
class Transforms {
public:
void Init(const YAML::Node& node, bool to_rgb = true);
void Init(const YAML::Node& node, std::string type, bool to_rgb = true);
// Returns a transform instance for the given operation name.
std::shared_ptr<Transform> CreateTransform(const std::string& name);
bool Run(cv::Mat* im, Blob::Ptr blob);
bool Run(cv::Mat* im, ImageBlob* data);
private:
// Operations held by this pipeline.
std::vector<std::shared_ptr<Transform>> transforms_;
// Defaults to true; presumably requests a BGR->RGB conversion - TODO confirm.
bool to_rgb_ = true;
// Model type string handed to Init().
std::string type_;
};
} // namespace PaddleX
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <map>
#include <vector>
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#else // Linux/Unix
#include <dirent.h>
#include <sys/io.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <string>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "include/paddlex/results.h"
#ifdef _WIN32
#define OS_PATH_SEP "\\"
#else
#define OS_PATH_SEP "/"
#endif
namespace PaddleX {
/*
* @brief
* Generate the visualization colormap with one color per class
*
* @param num_class: number of classes
* @return color map, the size of the vector is 3 * num_class
* */
std::vector<int> GenerateColorMap(int num_class);
/*
* @brief
* Visualize the detection result
*
* @param img: initial image matrix
* @param results: the detection result
* @param labels: label map
* @param colormap: visualization color map
* @param threshold: score threshold, 0.5 by default (presumably boxes scoring
*                   below it are not drawn - confirm in the implementation)
* @return visualized image matrix
* */
cv::Mat Visualize(const cv::Mat& img,
const DetResult& results,
const std::map<int, std::string>& labels,
const std::vector<int>& colormap,
float threshold = 0.5);
/*
* @brief
* Visualize the segmentation result
*
* @param img: initial image matrix
* @param result: the segmentation result
* @param labels: label map
* @param colormap: visualization color map
* @return visualized image matrix
* */
cv::Mat Visualize(const cv::Mat& img,
const SegResult& result,
const std::map<int, std::string>& labels,
const std::vector<int>& colormap);
/*
* @brief
* Generate the save path for a visualized image matrix
*
* @param save_dir: directory for saving the visualized image matrix
* @param file_path: source image file path
* @return path for saving the visualized result
* */
std::string generate_save_path(const std::string& save_dir,
const std::string& file_path);
} // namespace PaddleX
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from six import text_type as _text_type
import argparse
import sys
from utils import logging
import paddlex as pdx
def arg_parser():
    """Build the command-line parser of the OpenVINO export script.

    Returns:
        argparse.ArgumentParser: parser with the options --model_dir/-m,
            --save_dir/-s, --fixed_input_shape/-fs and --data_type/-dp.
    """
    # (flags, keyword arguments) for each option, in registration order.
    option_specs = (
        (("--model_dir", "-m"), dict(
            type=_text_type,
            default=None,
            help="define model directory path")),
        (("--save_dir", "-s"), dict(
            type=_text_type,
            default=None,
            help="path to save inference model")),
        (("--fixed_input_shape", "-fs"), dict(
            default=None,
            help="export openvino model with input shape:[w,h]")),
        (("--data_type", "-dp"), dict(
            default="FP32",
            help="option, FP32 or FP16, the data_type of openvino IR")),
    )
    cli = argparse.ArgumentParser()
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli
def _version_tuple(version):
    """Convert a dotted version string into a tuple of ints for numeric comparison."""
    numbers = []
    for piece in str(version).split('.'):
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        # Non-numeric pieces (e.g. "rc1") count as 0.
        numbers.append(int(digits) if digits else 0)
    return tuple(numbers)


def export_openvino_model(model, args):
    """Convert a PaddleX inference model to OpenVINO IR.

    The model is first converted to ONNX with x2paddle and the ONNX file is
    then handed to the OpenVINO Model Optimizer (`mo`).

    Args:
        model: model object returned by `pdx.load_model`; its `model_type`
            and class name are inspected.
        args: parsed command-line arguments providing `model_dir`, `save_dir`
            and `fixed_input_shape`. `fixed_input_shape` is a string whose
            first character is dropped and whose remainder is appended to
            "[1,3," (so "[w,h]" becomes "[1,3,w,h]").
            NOTE(review): confirm that the width/height order matches what
            the Model Optimizer expects.
    """
    if model.model_type == "detector" or model.__class__.__name__ == "FastSCNN":
        logging.error(
            "Only image classifier models and semantic segmentation models(except FastSCNN) are supported to export to openvino")
    try:
        import x2paddle
    except ImportError:
        # Only a missing package is reported as "not installed"; the bare
        # `except` used before also caught whatever the version error raised.
        logging.error(
            "You need to install x2paddle first, pip install x2paddle>=0.7.4")
    else:
        # Compare numerically: as strings, '0.10.0' sorts before '0.7.4'.
        if _version_tuple(x2paddle.__version__) < (0, 7, 4):
            logging.error("You need to upgrade x2paddle >= 0.7.4")

    import x2paddle.convert as x2pc
    x2pc.paddle2onnx(args.model_dir, args.save_dir)

    import mo.main as mo
    from mo.utils.cli_parser import get_onnx_cli_parser
    onnx_parser = get_onnx_cli_parser()
    # Register this script's own options so the Model Optimizer parser
    # accepts them.
    onnx_parser.add_argument("--model_dir", type=_text_type)
    onnx_parser.add_argument("--save_dir", type=_text_type)
    onnx_parser.add_argument("--fixed_input_shape")
    onnx_input = os.path.join(args.save_dir, 'x2paddle_model.onnx')
    onnx_parser.set_defaults(input_model=onnx_input)
    onnx_parser.set_defaults(output_dir=args.save_dir)
    # --fixed_input_shape defaults to None; slicing None used to raise a
    # TypeError, so the shape override is applied only when a shape is given.
    if args.fixed_input_shape is not None:
        shape = '[1,3,' + args.fixed_input_shape[1:]
        if model.__class__.__name__ == "YOLOV3":
            shape = shape + ",[1,2]"
            inputs = "image,im_size"
            onnx_parser.set_defaults(input=inputs)
        onnx_parser.set_defaults(input_shape=shape)
    mo.main(onnx_parser, 'onnx')
def main():
    """Command-line entry point: load a PaddleX model and export it to OpenVINO.

    Both --model_dir and --save_dir are required. Models whose status is
    "Normal" or "Prune" are reported as unsupported and are not exported.
    """
    parser = arg_parser()
    args = parser.parse_args()
    assert args.model_dir is not None, "--model_dir should be defined while exporting openvino model"
    assert args.save_dir is not None, "--save_dir should be defined to create openvino model"
    model = pdx.load_model(args.model_dir)
    if model.status == "Normal" or model.status == "Prune":
        logging.error(
            "Only support inference model, try to export model first as below,",
            exit=False)
        # The error is logged with exit=False, so stop here explicitly rather
        # than running the export on a model just reported as unsupported.
        return
    export_openvino_model(model, args)


if __name__ == "__main__":
    main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import argparse
import deploy
def arg_parser():
    """Build the command-line parser of the OpenVINO prediction demo.

    Returns:
        argparse.ArgumentParser: parser with the string options
            --model_dir/-m, --device/-d (default 'CPU'), --img/-i,
            --img_list/-l and --cfg_file/-c.
    """
    # (flags, default, help) for each option, in registration order; every
    # option is a string.
    option_specs = (
        (("--model_dir", "-m"), None, "path to openvino model .xml file"),
        (("--device", "-d"), 'CPU',
         "Specify the target device to infer on:[CPU, GPU, FPGA, HDDL, MYRIAD,HETERO]"
         "Default value is CPU"),
        (("--img", "-i"), None, "path to an image files"),
        (("--img_list", "-l"), None, "Path to a imglist"),
        (("--cfg_file", "-c"), None, "Path to PaddelX model yml file"),
    )
    cli = argparse.ArgumentParser()
    for flags, default_value, help_text in option_specs:
        cli.add_argument(
            *flags, type=str, default=default_value, help=help_text)
    return cli
def main():
    """Command-line entry point of the OpenVINO prediction demo.

    Builds a `deploy.Predictor` from --model_dir / --cfg_file on --device and
    runs it on every path listed in --img_list (one per line) or, when no
    list is given, on the single image passed with --img.
    """
    parser = arg_parser()
    args = parser.parse_args()
    model_xml = args.model_dir
    model_yaml = args.cfg_file
    if args.img is None and args.img_list is None:
        # Fail early with a usage message instead of predicting on None.
        parser.error("--img or --img_list should be defined")

    # model init
    # Always forward the requested device. The earlier substring test
    # ("CPU" not in device) fell back to the default device for composite
    # names that contain "CPU", e.g. "HETERO:FPGA,CPU".
    predictor = deploy.Predictor(model_xml, model_yaml, args.device)

    # predict
    if args.img_list is not None:
        # The context manager closes the list file even if a prediction raises.
        with open(args.img_list) as f:
            for line in f:
                im_path = line.rstrip('\r\n')
                if not im_path.strip():
                    continue  # skip blank lines
                print(im_path)
                predictor.predict(im_path)
    else:
        im_path = args.img
        predictor.predict(im_path)


if __name__ == "__main__":
    main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import os.path as osp
import time
import cv2
import numpy as np
import yaml
from six import text_type as _text_type
from openvino.inference_engine import IECore
class Predictor:
    """Runs a PaddleX-exported model with the OpenVINO Inference Engine.

    The model description (type, name, labels, transforms) is read from the
    PaddleX yaml file; the network itself is read from the OpenVINO .xml file
    and the .bin file next to it.

    Args:
        model_xml (str): path of the OpenVINO model .xml file.
        model_yaml (str): path of the PaddleX model yaml file.
        device (str): Inference Engine device name. Defaults to "CPU".
    """

    def __init__(self, model_xml, model_yaml, device="CPU"):
        self.device = device
        if not osp.exists(model_xml):
            print("model xml file is not exists in {}".format(model_xml))
        self.model_xml = model_xml
        # The weights file is expected next to the xml, with the same stem.
        self.model_bin = osp.splitext(model_xml)[0] + ".bin"
        if not osp.exists(model_yaml):
            print("model yaml file is not exists in {}".format(model_yaml))
        with open(model_yaml) as f:
            # NOTE(review): yaml.Loader can construct arbitrary Python
            # objects; only load model yaml files from trusted sources.
            self.info = yaml.load(f.read(), Loader=yaml.Loader)
        self.model_type = self.info['_Attributes']['model_type']
        self.model_name = self.info['Model']
        self.num_classes = self.info['_Attributes']['num_classes']
        self.labels = self.info['_Attributes']['labels']
        if self.info['Model'] == 'MaskRCNN':
            if self.info['_init_params']['with_fpn']:
                self.mask_head_resolution = 28
            else:
                self.mask_head_resolution = 14
        transforms_mode = self.info.get('TransformsMode', 'RGB')
        to_rgb = transforms_mode == 'RGB'
        self.transforms = self.build_transforms(self.info['Transforms'],
                                                to_rgb)
        self.predictor, self.net = self.create_predictor()
        self.total_time = 0
        self.count_num = 0

    def create_predictor(self):
        """Read the network and load it onto the configured device.

        Returns:
            tuple: (executable network, network) as returned by
                `IECore.load_network` and `IECore.read_network`.
        """
        # initialization for specified device
        print("Creating Inference Engine")
        ie = IECore()
        print("Loading network files:\n\t{}\n\t{}".format(self.model_xml,
                                                          self.model_bin))
        net = ie.read_network(model=self.model_xml, weights=self.model_bin)
        net.batch_size = 1
        network_config = {}
        if self.device == "MYRIAD":
            network_config = {'VPU_HW_STAGES_OPTIMIZATION': 'NO'}
        exec_net = ie.load_network(
            network=net, device_name=self.device, config=network_config)
        return exec_net, net

    def build_transforms(self, transforms_info, to_rgb=True):
        """Build the preprocessing pipeline described in the model yaml.

        Args:
            transforms_info (list): one single-key dict per operator, mapping
                the operator name to its keyword arguments.
            to_rgb (bool): stored on the composed transform when it has a
                `to_rgb` attribute.

        Returns:
            The composed transforms with the matching Arrange operator in
            last position.

        Raises:
            Exception: the model type is unknown or an operator is missing
                from the transforms module of this model type.
        """
        if self.model_type == "classifier":
            import transforms.cls_transforms as transforms
        elif self.model_type == "detector":
            import transforms.det_transforms as transforms
        elif self.model_type == "segmenter":
            import transforms.seg_transforms as transforms
        else:
            # Without this branch an unknown type surfaced as a NameError.
            raise Exception("Unrecognized model type: {}".format(
                self.model_type))
        op_list = list()
        for op_info in transforms_info:
            op_name = list(op_info.keys())[0]
            op_attr = op_info[op_name]
            if not hasattr(transforms, op_name):
                raise Exception(
                    "There's no operator named '{}' in transforms of {}".
                    format(op_name, self.model_type))
            op_list.append(getattr(transforms, op_name)(**op_attr))
        eval_transforms = transforms.Compose(op_list)
        if hasattr(eval_transforms, 'to_rgb'):
            eval_transforms.to_rgb = to_rgb
        self.arrange_transforms(eval_transforms)
        return eval_transforms

    def arrange_transforms(self, eval_transforms):
        """Make the pipeline end with the test-mode Arrange operator.

        An Arrange operator already in last position is replaced; otherwise
        one is appended.

        Raises:
            Exception: the model type is unknown.
        """
        if self.model_type == 'classifier':
            import transforms.cls_transforms as transforms
            arrange_transform = transforms.ArrangeClassifier
        elif self.model_type == 'segmenter':
            import transforms.seg_transforms as transforms
            arrange_transform = transforms.ArrangeSegmenter
        elif self.model_type == 'detector':
            import transforms.det_transforms as transforms
            arrange_name = 'Arrange{}'.format(self.model_name)
            arrange_transform = getattr(transforms, arrange_name)
        else:
            raise Exception("Unrecognized model type: {}".format(
                self.model_type))
        if type(eval_transforms.transforms[-1]).__name__.startswith('Arrange'):
            eval_transforms.transforms[-1] = arrange_transform(mode='test')
        else:
            eval_transforms.transforms.append(arrange_transform(mode='test'))

    def raw_predict(self, preprocessed_input):
        """Run one synchronous inference.

        Args:
            preprocessed_input (dict): output of `preprocess`.

        Returns:
            The result of `infer` on the executable network.
        """
        self.count_num += 1
        feed_dict = {}
        if self.model_name == "YOLOv3":
            # YOLOv3 has two inputs, told apart by rank: the 2-D one takes
            # the image size, the 4-D one the image itself.
            inputs = self.net.inputs
            for name in inputs:
                if len(inputs[name].shape) == 2:
                    feed_dict[name] = preprocessed_input['im_size']
                elif len(inputs[name].shape) == 4:
                    feed_dict[name] = preprocessed_input['image']
        else:
            input_blob = next(iter(self.net.inputs))
            feed_dict[input_blob] = preprocessed_input['image']
        # Start sync inference
        print("Starting inference in synchronous mode")
        res = self.predictor.infer(inputs=feed_dict)
        # Processing output blob
        print("Processing output blob")
        return res

    def preprocess(self, image):
        """Apply the transforms and add a leading batch axis.

        Args:
            image: value handed to the composed transforms (the classifier
                transforms accept an image path or an np.ndarray).

        Returns:
            dict: 'image' plus, depending on the model, 'im_size',
                'im_info' and 'im_shape'.
        """
        res = dict()
        if self.model_type == "classifier":
            im, = self.transforms(image)
            im = np.expand_dims(im, axis=0).copy()
            res['image'] = im
        elif self.model_type == "detector":
            if self.model_name == "YOLOv3":
                im, im_shape = self.transforms(image)
                im = np.expand_dims(im, axis=0).copy()
                im_shape = np.expand_dims(im_shape, axis=0).copy()
                res['image'] = im
                res['im_size'] = im_shape
            if self.model_name.count('RCNN') > 0:
                im, im_resize_info, im_shape = self.transforms(image)
                im = np.expand_dims(im, axis=0).copy()
                im_resize_info = np.expand_dims(im_resize_info, axis=0).copy()
                im_shape = np.expand_dims(im_shape, axis=0).copy()
                res['image'] = im
                res['im_info'] = im_resize_info
                res['im_shape'] = im_shape
        elif self.model_type == "segmenter":
            im, im_info = self.transforms(image)
            im = np.expand_dims(im, axis=0).copy()
            res['image'] = im
            res['im_info'] = im_info
        return res

    def classifier_postprocess(self, preds, topk=1):
        """Post-process the prediction of a classification model.

        Args:
            preds (dict): raw outputs keyed by output name.
            topk (int): number of best classes to keep, capped at
                `num_classes`.

        Returns:
            list: dicts with 'category_id', 'category' and 'score', ordered
                by decreasing score.
        """
        true_topk = min(self.num_classes, topk)
        output_name = next(iter(self.net.outputs))
        pred_label = np.argsort(-preds[output_name][0])[:true_topk]
        result = [{
            'category_id': l,
            'category': self.labels[l],
            'score': preds[output_name][0][l],
        } for l in pred_label]
        print(result)
        return result

    def segmenter_postprocess(self, preds, preprocessed_inputs):
        """Post-process the prediction of a semantic segmentation model.

        The recorded preprocessing steps are undone in reverse order so that
        both maps match the original image.

        Args:
            preds (dict): raw outputs keyed by output name.
            preprocessed_inputs (dict): output of `preprocess`; 'im_info'
                lists the ('resize' | 'padding', size) steps that were
                applied.

        Returns:
            dict: {'label_map': ..., 'score_map': ...}.

        Raises:
            Exception: 'im_info' contains an unknown step.
        """
        # The first output is skipped; the second and third are read as the
        # score map and the label map.
        # NOTE(review): this depends on the output order of the network.
        it = iter(self.net.outputs)
        next(it)
        score_name = next(it)
        score_map = np.squeeze(preds[score_name])
        score_map = np.transpose(score_map, (1, 2, 0))
        label_name = next(it)
        label_map = np.squeeze(preds[label_name]).astype('uint8')
        im_info = preprocessed_inputs['im_info']
        for info in im_info[::-1]:
            if info[0] == 'resize':
                w, h = info[1][1], info[1][0]
                # The third positional parameter of cv2.resize is `dst`, so
                # the flag must be passed by keyword; otherwise labels were
                # resized with the default linear interpolation.
                label_map = cv2.resize(
                    label_map, (w, h), interpolation=cv2.INTER_NEAREST)
                score_map = cv2.resize(
                    score_map, (w, h), interpolation=cv2.INTER_LINEAR)
            elif info[0] == 'padding':
                w, h = info[1][1], info[1][0]
                label_map = label_map[0:h, 0:w]
                score_map = score_map[0:h, 0:w, :]
            else:
                raise Exception("Unexpected info '{}' in im_info".format(info[
                    0]))
        return {'label_map': label_map, 'score_map': score_map}

    def detector_postprocess(self, preds, preprocessed_inputs):
        """Post-process the prediction of a detection model.

        Args:
            preds (dict): raw outputs keyed by output name.
            preprocessed_inputs (dict): output of `preprocess` (unused).

        Returns:
            list: the rows of the first output whose first element is
                positive, each converted to a Python list.
        """
        output_name = next(iter(self.net.outputs))
        outputs = preds[output_name][0]
        result = []
        for out in outputs:
            if out[0] > 0:
                result.append(out.tolist())
        print(result)
        return result

    def predict(self, image, topk=1, threshold=0.5):
        """Preprocess an image, run inference and post-process the output.

        Args:
            image: image handed to `preprocess`.
            topk (int): number of classes returned for classifiers.
            threshold (float): currently unused.

        Returns:
            The post-processed result of the model type (see the
            *_postprocess methods), or None for an unknown model type.
        """
        preprocessed_input = self.preprocess(image)
        model_pred = self.raw_predict(preprocessed_input)
        results = None
        if self.model_type == "classifier":
            results = self.classifier_postprocess(model_pred, topk)
        elif self.model_type == "detector":
            results = self.detector_postprocess(model_pred, preprocessed_input)
        elif self.model_type == "segmenter":
            results = self.segmenter_postprocess(model_pred,
                                                 preprocessed_input)
        # The computed result used to be dropped; hand it back to the caller.
        return results
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import cls_transforms
from . import det_transforms
from . import seg_transforms
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ops import *
import random
import os.path as osp
import numpy as np
from PIL import Image, ImageEnhance
class ClsTransform:
    """Base class of the classification transforms."""

    def __init__(self):
        """Nothing to initialise; the base class keeps no state."""
class Compose(ClsTransform):
    """Apply a list of preprocessing/augmentation operators to the input.

    Every operator receives and returns the image in [H, W, C] layout, where
    H is the image height, W the width and C the number of channels.

    Args:
        transforms (list): preprocessing/augmentation operators.

    Raises:
        TypeError: `transforms` is not a list.
        ValueError: `transforms` is empty.
    """

    def __init__(self, transforms):
        if not isinstance(transforms, list):
            raise TypeError('The transforms must be a list!')
        if len(transforms) < 1:
            raise ValueError('The length of transforms ' + \
                            'must be equal or larger than 1!')
        self.transforms = transforms

    def __call__(self, im, label=None):
        """Run every operator in order.

        Args:
            im (str/np.ndarray): image path or image data as np.ndarray.
            label (int): class id of the image.

        Returns:
            tuple: the fields required by the network; they are determined
                by the last operator in `transforms`.

        Raises:
            Exception: `im` is an np.ndarray that is not 3-dimensional.
            TypeError: the image file cannot be read.
        """
        if isinstance(im, np.ndarray):
            if len(im.shape) != 3:
                raise Exception(
                    "im should be 3-dimension, but now is {}-dimensions".
                    format(len(im.shape)))
        else:
            try:
                im = cv2.imread(im).astype('float32')
            except Exception:
                # cv2.imread returns None for an unreadable file, so the
                # chained call fails. Catching Exception rather than using a
                # bare except keeps KeyboardInterrupt/SystemExit from being
                # turned into a TypeError.
                raise TypeError('Can\'t read The image file {}!'.format(im))
        # Applied to ndarray inputs as well, which are therefore treated as BGR.
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        for op in self.transforms:
            outputs = op(im, label)
            im = outputs[0]
            if len(outputs) == 2:
                label = outputs[1]
        return outputs

    def add_augmenters(self, augmenters):
        """Put additional operators in front of the existing ones.

        A message is printed for every augmenter whose type is already part
        of the pipeline; the augmenter is still added.

        Args:
            augmenters (list): operators to prepend.

        Raises:
            Exception: `augmenters` is not a list.
        """
        if not isinstance(augmenters, list):
            raise Exception(
                "augmenters should be list type in func add_augmenters()")
        transform_names = [type(x).__name__ for x in self.transforms]
        for aug in augmenters:
            if type(aug).__name__ in transform_names:
                print(
                    "{} is already in ComposedTransforms, need to remove it from add_augmenters().".
                    format(type(aug).__name__))
        self.transforms = augmenters + self.transforms
class Normalize(ClsTransform):
    """Standardize an image.

    1. Scale the image to the range [0.0, 1.0].
    2. Subtract the mean and divide by the standard deviation.

    Args:
        mean (list): mean of the image dataset. Defaults to
            [0.485, 0.456, 0.406].
        std (list): standard deviation of the image dataset. Defaults to
            [0.229, 0.224, 0.225].
    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, im, label=None):
        """
        Args:
            im (np.ndarray): image data.
            label (int): class id of the image.

        Returns:
            tuple: (im, ) when label is None, otherwise (im, label).
        """
        # Two leading axes let the per-channel statistics broadcast over an
        # [H, W, C] image.
        channel_mean = np.array(self.mean)[None, None, :]
        channel_std = np.array(self.std)[None, None, :]
        im = normalize(im, channel_mean, channel_std)
        return (im, ) if label is None else (im, label)
class ResizeByShort(ClsTransform):
    """Resize an image according to its short side.

    1. Get the lengths of the long and the short side of the image.
    2. The resize ratio of height and width is short_size / short side.
    3. If max_size > 0 and the long side would exceed max_size with that
       ratio, the ratio becomes max_size / long side.
    4. Resize the image with the resulting ratio.

    Args:
        short_size (int): target length of the short side. Defaults to 256.
        max_size (int): upper limit of the long side. Defaults to -1.
    """

    def __init__(self, short_size=256, max_size=-1):
        self.short_size = short_size
        self.max_size = max_size

    def __call__(self, im, label=None):
        """
        Args:
            im (np.ndarray): image data.
            label (int): class id of the image.

        Returns:
            tuple: (im, ) when label is None, otherwise (im, label).
        """
        height, width = im.shape[0], im.shape[1]
        shorter = min(height, width)
        longer = max(height, width)
        ratio = float(self.short_size) / shorter
        too_long = self.max_size > 0 and \
            np.round(ratio * longer) > self.max_size
        if too_long:
            ratio = float(self.max_size) / float(longer)
        target = (int(round(width * ratio)), int(round(height * ratio)))
        im = cv2.resize(im, target, interpolation=cv2.INTER_LINEAR)
        return (im, ) if label is None else (im, label)
class CenterCrop(ClsTransform):
    """Crop a square with side `crop_size` around the image centre.

    1. Compute the start point of the crop.
    2. Crop the image.

    Args:
        crop_size (int): side length of the crop. Defaults to 224.
    """

    def __init__(self, crop_size=224):
        self.crop_size = crop_size

    def __call__(self, im, label=None):
        """
        Args:
            im (np.ndarray): image data.
            label (int): class id of the image.

        Returns:
            tuple: (im, ) when label is None, otherwise (im, label).
        """
        cropped = center_crop(im, self.crop_size)
        return (cropped, ) if label is None else (cropped, label)
class ArrangeClassifier(ClsTransform):
    """Collect the fields needed for training/evaluation/prediction.

    Note: users do not need to call this operator explicitly.

    Args:
        mode (str): purpose of the data, one of
            ['train', 'eval', 'test', 'quant'].

    Raises:
        ValueError: mode is not one of ['train', 'eval', 'test', 'quant'].
    """

    def __init__(self, mode=None):
        allowed_modes = ['train', 'eval', 'test', 'quant']
        if mode not in allowed_modes:
            raise ValueError(
                "mode must be in ['train', 'eval', 'test', 'quant']!")
        self.mode = mode

    def __call__(self, im, label=None):
        """
        Args:
            im (np.ndarray): image data.
            label (int): class id of the image.

        Returns:
            tuple: (im, label) when mode is 'train' or 'eval';
                (im, ) when mode is 'test' or 'quant'.
        """
        # Reorder the axes with permute() and cast to float32.
        arranged = permute(im, False).astype('float32')
        if self.mode in ('train', 'eval'):
            return (arranged, label)
        return (arranged, )
class ComposedClsTransforms(Compose):
    """Default transform pipeline of the classification models.

    Evaluation/prediction stage:
        1. Resize the image proportionally so that its short side is
           int(width * 1.14), where width is the crop size.
        2. Crop an image of the crop size from the centre.
        3. Normalize the image.

    The training stage is not available here: the 'train' branch defines no
    operators.

    Args:
        mode (str): stage of the pipeline, 'train', 'eval' or 'test'.
        crop_size (int|list): size of the image fed to the model.
        mean (list): image mean.
        std (list): image standard deviation.

    Raises:
        Exception: crop_size is not square or not a multiple of 32.
        NotImplementedError: mode is 'train'.
    """

    def __init__(self,
                 mode,
                 crop_size=[224, 224],
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]):
        width = crop_size
        if isinstance(crop_size, list):
            if crop_size[0] != crop_size[1]:
                raise Exception(
                    "In classifier model, width and height should be equal, please modify your parameter `crop_size`"
                )
            width = crop_size[0]
        if width % 32 != 0:
            raise Exception(
                "In classifier model, width and height should be multiple of 32, e.g 224、256、320...., please modify your parameter `crop_size`"
            )

        if mode == 'train':
            # The train branch used to fall through without defining any
            # operators, which failed later with an unbound-variable error.
            raise NotImplementedError(
                "ComposedClsTransforms does not define training transforms, "
                "use mode 'eval' or 'test'")

        # Transforms of the evaluation/prediction stage.
        transforms = [
            ResizeByShort(short_size=int(width * 1.14)),
            CenterCrop(crop_size=width), Normalize(
                mean=mean, std=std)
        ]

        super(ComposedClsTransforms, self).__init__(transforms)
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import math
import numpy as np
from PIL import Image, ImageEnhance
def normalize(im, mean, std):
    """Divide `im` by 255.0, then subtract `mean` and divide by `std`.

    The division by 255.0 creates a new array, so the caller's `im` is not
    modified; the subtraction and division are applied in place on that copy.
    """
    scaled = im / 255.0
    scaled -= mean
    scaled /= std
    return scaled
def permute(im, to_bgr=False):
    """Move axis 2 of `im` to the front (HWC -> CHW for a 3-D image).

    When `to_bgr` is true, the first three entries of the new leading axis are
    additionally reordered as [2, 1, 0].
    """
    channel_first = np.moveaxis(im, 2, 0)
    if not to_bgr:
        return channel_first
    return channel_first[[2, 1, 0], :, :]
def resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR):
    """Resize `im` so that its longer side becomes `long_size`.

    Both sides are multiplied by the same factor and rounded to the nearest
    integer before being handed to cv2.resize.
    """
    height, width = im.shape[0], im.shape[1]
    scale = float(long_size) / float(max(height, width))
    new_size = (int(round(width * scale)), int(round(height * scale)))
    return cv2.resize(im, new_size, interpolation=interpolation)
def resize(im, target_size=608, interp=cv2.INTER_LINEAR):
    """Resize `im` to a fixed size.

    `target_size` is either a single number used for both width and height,
    or a list/tuple whose first two items are (width, height).
    """
    if isinstance(target_size, (list, tuple)):
        w, h = target_size[0], target_size[1]
    else:
        w = h = target_size
    return cv2.resize(im, (w, h), interpolation=interp)
def random_crop(im,
                crop_size=224,
                lower_scale=0.08,
                lower_ratio=3. / 4,
                upper_ratio=4. / 3):
    """Crop a random sub-window of `im` and resize it to crop_size x crop_size.

    An aspect ratio is drawn from [lower_ratio, upper_ratio] and an area
    fraction from [lower_scale, 1.0]; both bounds of the area fraction are
    capped so the window stays inside the image. The window position is then
    drawn uniformly. Random numbers are consumed in the order: aspect ratio,
    area fraction, row offset, column offset.
    """
    rows, cols = im.shape[0], im.shape[1]

    aspect = math.sqrt(np.random.uniform(lower_ratio, upper_ratio))
    rel_w = 1. * aspect
    rel_h = 1. / aspect

    # Largest area fraction for which the window still fits in the image.
    bound = min((float(rows) / cols) / (rel_h**2),
                (float(cols) / rows) / (rel_w**2))
    scale_hi = min(1.0, bound)
    scale_lo = min(lower_scale, bound)

    area = rows * cols * np.random.uniform(scale_lo, scale_hi)
    side = math.sqrt(area)
    crop_w = int(side * rel_w)
    crop_h = int(side * rel_h)

    top = np.random.randint(0, rows - crop_h + 1)
    left = np.random.randint(0, cols - crop_w + 1)

    patch = im[top:top + crop_h, left:left + crop_w, :]
    return cv2.resize(patch, (crop_size, crop_size))
def center_crop(im, crop_size=224):
    """Return the crop_size x crop_size window centered in `im`.

    The result is a slice of `im`, not a copy. No check is made that the
    image is at least `crop_size` on each side.
    """
    height, width = im.shape[:2]
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return im[top:top + crop_size, left:left + crop_size, :]
def horizontal_flip(im):
    """Reverse `im` along axis 1 for 2-D and 3-D inputs.

    Inputs of any other rank are returned unchanged.
    """
    rank = len(im.shape)
    if rank == 3:
        return im[:, ::-1, :]
    if rank == 2:
        return im[:, ::-1]
    return im
def vertical_flip(im):
    """Reverse `im` along axis 0 for 2-D and 3-D inputs.

    Inputs of any other rank are returned unchanged.
    """
    rank = len(im.shape)
    if rank == 3:
        return im[::-1, :, :]
    if rank == 2:
        return im[::-1, :]
    return im
def bgr2rgb(im):
    """Return `im` with the order of axis 2 reversed."""
    reverse_axis2 = (slice(None), slice(None), slice(None, None, -1))
    return im[reverse_axis2]
def hue(im, hue_lower, hue_upper):
    """Apply a random hue shift to `im`.

    A value `delta` is drawn uniformly from [hue_lower, hue_upper] and turned
    into a rotation by delta * pi. The rotation is sandwiched between two
    fixed 3x3 matrices (named tyiq / ityiq in the original; the coefficients
    look like the RGB<->YIQ conversion pair) and the combined matrix is
    applied to the last axis of `im`. A new array is returned.
    """
    delta = np.random.uniform(hue_lower, hue_upper)
    cos_d = np.cos(delta * np.pi)
    sin_d = np.sin(delta * np.pi)

    rotation = np.array([[1.0, 0.0, 0.0],
                         [0.0, cos_d, -sin_d],
                         [0.0, sin_d, cos_d]])
    to_yiq = np.array([[0.299, 0.587, 0.114],
                       [0.596, -0.274, -0.321],
                       [0.211, -0.523, 0.311]])
    from_yiq = np.array([[1.0, 0.956, 0.621],
                         [1.0, -0.272, -0.647],
                         [1.0, -1.107, 1.705]])

    combined = np.dot(np.dot(from_yiq, rotation), to_yiq).T
    return np.dot(im, combined)
def saturation(im, saturation_lower, saturation_upper):
    """Blend `im` in place with a weighted sum over axis 2.

    A factor `delta` is drawn uniformly from the given range; the result is
    delta * im + (1 - delta) * gray, where gray is the sum over axis 2 of
    `im` weighted by (0.299, 0.587, 0.114). `im` itself is modified and
    returned, so it must support in-place float arithmetic.
    """
    delta = np.random.uniform(saturation_lower, saturation_upper)
    weights = np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
    gray = (im * weights).sum(axis=2, keepdims=True)
    gray *= (1.0 - delta)
    im *= delta
    im += gray
    return im
def contrast(im, contrast_lower, contrast_upper):
    """Multiply `im` in place by a factor drawn uniformly from the given range."""
    factor = np.random.uniform(contrast_lower, contrast_upper)
    im *= factor
    return im
def brightness(im, brightness_lower, brightness_upper):
    """Add to `im` in place an offset drawn uniformly from the given range."""
    offset = np.random.uniform(brightness_lower, brightness_upper)
    im += offset
    return im
def rotate(im, rotate_lower, rotate_upper):
    """Rotate `im` by a random whole number of degrees.

    The angle is drawn uniformly from the given range and truncated to an
    int before calling `im.rotate`, so `im` must expose a `rotate` method
    (presumably a PIL image; verify against the caller).
    """
    angle = int(np.random.uniform(rotate_lower, rotate_upper))
    return im.rotate(angle)
def resize_padding(im, max_side_len=2400):
    '''
    Resize an image so that both sides are multiples of 32.

    The longer side is first limited to `max_side_len` (keeping the aspect
    ratio). Each side that is not already a multiple of 32 is then set to
    (side // 32 - 1) * 32, and finally clamped to at least 32.

    :param im: input image with shape (height, width, channels)
    :param max_side_len: limit on the longer side
    :return: the resized image and a (1, 2) array [[ratio_h, ratio_w]] of
             new size divided by original size
    '''
    src_h, src_w, _ = im.shape

    # Limit the longer side.
    longest = max(src_h, src_w)
    if longest > max_side_len:
        ratio = float(max_side_len) / longest
    else:
        ratio = 1.

    def _snap(side):
        # Scale one side, then bring it onto the 32-pixel grid.
        side = int(side * ratio)
        if side % 32 != 0:
            side = (side // 32 - 1) * 32
        return max(32, side)

    dst_h = _snap(src_h)
    dst_w = _snap(src_w)

    im = cv2.resize(im, (int(dst_w), int(dst_h)))
    ratios = np.array([dst_h / float(src_h),
                       dst_w / float(src_w)]).reshape(-1, 2)
    return im, ratios
此差异已折叠。
# Fetch the pre-compiled OpenCV bundle into ./deps unless it is already unpacked.
OPENCV_URL=https://paddleseg.bj.bcebos.com/deploy/docker/opencv3gcc4.8.tar.bz2
if ! [ -d "./deps/opencv3gcc4.8" ]; then
    mkdir -p deps
    cd deps
    # -c resumes a partially downloaded archive.
    wget -c ${OPENCV_URL}
    tar -xvjf opencv3gcc4.8.tar.bz2
    # The archive is no longer needed once it has been extracted.
    rm -rf opencv3gcc4.8.tar.bz2
    cd ..
fi
# Path to the pre-built OpenVINO libraries
OPENVINO_DIR=/path/to/inference_engine/
# Path to the pre-built gflags libraries
GFLAGS_DIR=/path/to/gflags
# Path to the pre-built OpenVINO libraries
OPENVINO_DIR=$INTEL_OPENVINO_DIR/inference_engine
# Path to the ngraph libraries, usually produced when OpenVINO is built
NGRAPH_LIB=/path/to/ngraph/lib/
NGRAPH_LIB=$INTEL_OPENVINO_DIR/deployment_tools/ngraph/lib
# Path to the pre-built gflags libraries
GFLAGS_DIR=$(pwd)/deps/gflags
# Path to the pre-built glog libraries
GLOG_DIR=$(pwd)/deps/glog
# Use the bundled pre-built OpenCV
OPENCV_DIR=$(pwd)/deps/opencv/
# CPU architecture
ARCH=x86
export ARCH
# Path to the pre-built OpenCV libraries; no change needed when the bundled build is used
OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/
# Download the bundled pre-built libraries
sh $(pwd)/scripts/bootstrap.sh
# Download and build the third-party libraries
sh $(pwd)/scripts/install_third-party.sh
# Start from a clean build directory
rm -rf build
mkdir -p build
......@@ -16,6 +25,8 @@ cd build
cmake .. \
-DOPENCV_DIR=${OPENCV_DIR} \
-DGFLAGS_DIR=${GFLAGS_DIR} \
-DGLOG_DIR=${GLOG_DIR} \
-DOPENVINO_DIR=${OPENVINO_DIR} \
-DNGRAPH_LIB=${NGRAPH_LIB}
-DNGRAPH_LIB=${NGRAPH_LIB} \
-DARCH=${ARCH}
make
# Download (and where needed build) the third-party libraries under ./deps.
# Each section is skipped when its target directory already exists.
if [ ! -d "./deps" ]; then
    mkdir deps
fi

# gflags: the guard must test the directory that `git clone` creates
# (deps/gflags); testing "deps/gflag" was always true and made re-runs fail
# on the clone of an already existing checkout.
if [ ! -d "./deps/gflags" ]; then
    cd deps
    git clone https://github.com/gflags/gflags
    cd gflags
    cmake .
    make -j 8
    cd ..
    cd ..
fi

# glog: built with autotools, which are installed through apt first.
if [ ! -d "./deps/glog" ]; then
    cd deps
    git clone https://github.com/google/glog
    sudo apt-get install autoconf automake libtool
    cd glog
    ./autogen.sh
    ./configure
    make -j 8
    cd ..
    cd ..
fi

# OpenCV: pick the pre-built archive matching $ARCH (anything other than
# "x86" selects the ARM archive).
if [ "$ARCH" = "x86" ]; then
    OPENCV_URL=https://bj.bcebos.com/paddlex/deploy/x86opencv/opencv.tar.bz2
else
    OPENCV_URL=https://bj.bcebos.com/paddlex/deploy/armopencv/opencv.tar.bz2
fi
if [ ! -d "./deps/opencv" ]; then
    cd deps
    wget -c ${OPENCV_URL}
    tar xvfj opencv.tar.bz2
    rm -rf opencv.tar.bz2
    cd ..
fi
......@@ -13,28 +13,47 @@
// limitations under the License.
#include "include/paddlex/paddlex.h"
#include <iostream>
#include <fstream>
using namespace InferenceEngine;
namespace PaddleX {
void Model::create_predictor(const std::string& model_dir,
const std::string& cfg_dir,
const std::string& cfg_file,
std::string device) {
Core ie;
network_ = ie.ReadNetwork(model_dir, model_dir.substr(0, model_dir.size() - 4) + ".bin");
InferenceEngine::Core ie;
network_ = ie.ReadNetwork(
model_dir, model_dir.substr(0, model_dir.size() - 4) + ".bin");
network_.setBatchSize(1);
InputInfo::Ptr input_info = network_.getInputsInfo().begin()->second;
input_info->getPreProcess().setResizeAlgorithm(RESIZE_BILINEAR);
input_info->setLayout(Layout::NCHW);
input_info->setPrecision(Precision::FP32);
executable_network_ = ie.LoadNetwork(network_, device);
load_config(cfg_dir);
InferenceEngine::InputsDataMap inputInfo(network_.getInputsInfo());
std::string imageInputName;
for (const auto & inputInfoItem : inputInfo) {
if (inputInfoItem.second->getTensorDesc().getDims().size() == 4) {
imageInputName = inputInfoItem.first;
inputInfoItem.second->setPrecision(InferenceEngine::Precision::FP32);
inputInfoItem.second->getPreProcess().setResizeAlgorithm(
InferenceEngine::RESIZE_BILINEAR);
inputInfoItem.second->setLayout(InferenceEngine::Layout::NCHW);
}
if (inputInfoItem.second->getTensorDesc().getDims().size() == 2) {
imageInputName = inputInfoItem.first;
inputInfoItem.second->setPrecision(InferenceEngine::Precision::FP32);
}
}
if (device == "MYRIAD") {
std::map<std::string, std::string> networkConfig;
networkConfig["VPU_HW_STAGES_OPTIMIZATION"] = "ON";
executable_network_ = ie.LoadNetwork(network_, device, networkConfig);
} else {
executable_network_ = ie.LoadNetwork(network_, device);
}
load_config(cfg_file);
}
bool Model::load_config(const std::string& cfg_dir) {
YAML::Node config = YAML::LoadFile(cfg_dir);
bool Model::load_config(const std::string& cfg_file) {
YAML::Node config = YAML::LoadFile(cfg_file);
type = config["_Attributes"]["model_type"].as<std::string>();
name = config["Model"].as<std::string>();
bool to_rgb = true;
......@@ -48,22 +67,26 @@ bool Model::load_config(const std::string& cfg_dir) {
return false;
}
}
// 构建数据处理流
transforms_.Init(config["Transforms"], to_rgb);
// 读入label list
labels.clear();
labels = config["_Attributes"]["labels"].as<std::vector<std::string>>();
// init preprocess ops
transforms_.Init(config["Transforms"], type, to_rgb);
// read label list
for (const auto& item : config["_Attributes"]["labels"]) {
int index = labels.size();
labels[index] = item.as<std::string>();
}
return true;
}
bool Model::preprocess(cv::Mat* input_im) {
if (!transforms_.Run(input_im, inputs_)) {
bool Model::preprocess(cv::Mat* input_im, ImageBlob* inputs) {
if (!transforms_.Run(input_im, inputs)) {
return false;
}
return true;
}
bool Model::predict(const cv::Mat& im, ClsResult* result) {
inputs_.clear();
if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!"
......@@ -75,34 +98,221 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
<< std::endl;
return false;
}
// 处理输入图像
InferRequest infer_request = executable_network_.CreateInferRequest();
// preprocess
InferenceEngine::InferRequest infer_request =
executable_network_.CreateInferRequest();
std::string input_name = network_.getInputsInfo().begin()->first;
inputs_ = infer_request.GetBlob(input_name);
auto im_clone = im.clone();
if (!preprocess(&im_clone)) {
inputs_.blob = infer_request.GetBlob(input_name);
cv::Mat im_clone = im.clone();
if (!preprocess(&im_clone, &inputs_)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
// predict
infer_request.Infer();
std::string output_name = network_.getOutputsInfo().begin()->first;
output_ = infer_request.GetBlob(output_name);
MemoryBlob::CPtr moutput = as<MemoryBlob>(output_);
InferenceEngine::MemoryBlob::CPtr moutput =
InferenceEngine::as<InferenceEngine::MemoryBlob>(output_);
auto moutputHolder = moutput->rmap();
float* outputs_data = moutputHolder.as<float *>();
// 对模型输出结果进行后处理
// post process
auto ptr = std::max_element(outputs_data, outputs_data+sizeof(outputs_data));
result->category_id = std::distance(outputs_data, ptr);
result->score = *ptr;
result->category = labels[result->category_id];
//for (int i=0;i<sizeof(outputs_data);i++){
// std::cout << labels[i] << std::endl;
// std::cout << outputs_[i] << std::endl;
// }
return true;
}
// Runs object detection on `im` and appends the detected boxes to `result`.
//
// The network output is read as consecutive groups of 6 floats:
//   [category_id, score, xmin, ymin, xmax, ymax]
// Boxes are stored as {xmin, ymin, width, height} with
// width = xmax - xmin + 1 and height = ymax - ymin + 1.
//
// Returns false when the loaded model is not a detector or when
// preprocessing fails, true otherwise. (The original fell off the end of
// the function without returning a value, which is undefined behaviour for
// a function returning bool.)
bool Model::predict(const cv::Mat& im, DetResult* result) {
  inputs_.clear();
  result->clear();
  if (type == "classifier") {
    std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
                 "to function predict()!" << std::endl;
    return false;
  } else if (type == "segmenter") {
    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
                 "to function predict()!" << std::endl;
    return false;
  }

  InferenceEngine::InferRequest infer_request =
      executable_network_.CreateInferRequest();

  // Bind the input blobs by rank: the 4-D input receives the image, the
  // 2-D input receives the original image size.
  InferenceEngine::InputsDataMap input_maps = network_.getInputsInfo();
  std::string inputName;
  for (const auto& input_map : input_maps) {
    if (input_map.second->getTensorDesc().getDims().size() == 4) {
      inputName = input_map.first;
      inputs_.blob = infer_request.GetBlob(inputName);
    }
    if (input_map.second->getTensorDesc().getDims().size() == 2) {
      inputName = input_map.first;
      inputs_.ori_im_size_ = infer_request.GetBlob(inputName);
    }
  }

  // Preprocess a copy so the caller's image is left untouched.
  cv::Mat im_clone = im.clone();
  if (!preprocess(&im_clone, &inputs_)) {
    std::cerr << "Preprocess failed!" << std::endl;
    return false;
  }

  infer_request.Infer();

  // Only the first entry of the output map is read.
  InferenceEngine::OutputsDataMap out_map = network_.getOutputsInfo();
  auto iter = out_map.begin();
  std::string outputName = iter->first;
  InferenceEngine::Blob::Ptr output = infer_request.GetBlob(outputName);
  InferenceEngine::MemoryBlob::CPtr moutput =
      InferenceEngine::as<InferenceEngine::MemoryBlob>(output);
  InferenceEngine::TensorDesc blob_output = moutput->getTensorDesc();
  std::vector<size_t> output_shape = blob_output.getDims();
  auto moutputHolder = moutput->rmap();
  float* data = moutputHolder.as<float *>();

  int size = 1;
  for (auto& i : output_shape) {
    size *= static_cast<int>(i);
  }
  int num_boxes = size / 6;
  for (int i = 0; i < num_boxes; ++i) {
    // NOTE(review): rows whose first value is <= 0 are skipped, which also
    // drops category id 0 -- confirm this is intended.
    if (data[i * 6] > 0) {
      Box box;
      box.category_id = static_cast<int>(data[i * 6]);
      box.category = labels[box.category_id];
      box.score = data[i * 6 + 1];
      float xmin = data[i * 6 + 2];
      float ymin = data[i * 6 + 3];
      float xmax = data[i * 6 + 4];
      float ymax = data[i * 6 + 5];
      float w = xmax - xmin + 1;
      float h = ymax - ymin + 1;
      box.coordinate = {xmin, ymin, w, h};
      result->boxes.push_back(std::move(box));
    }
  }
  return true;
}
} // namespce of PaddleX
// Runs semantic segmentation on `im` and fills `result` with a label map
// and a score map.
//
// Returns false when the loaded model is a classifier or detector, or when
// preprocessing fails; returns true otherwise. After inference, the resize
// and padding steps recorded during preprocessing are undone in reverse
// order before the maps are stored in `result`.
bool Model::predict(const cv::Mat& im, SegResult* result) {
result->clear();
inputs_.clear();
if (type == "classifier") {
std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
"to function predict()!" << std::endl;
return false;
} else if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!" << std::endl;
return false;
}
// init infer
InferenceEngine::InferRequest infer_request =
executable_network_.CreateInferRequest();
// Only the first network input is used for the image blob.
std::string input_name = network_.getInputsInfo().begin()->first;
inputs_.blob = infer_request.GetBlob(input_name);
// preprocess
// Work on a copy so the caller's image is left untouched.
cv::Mat im_clone = im.clone();
if (!preprocess(&im_clone, &inputs_)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
// predict
infer_request.Infer();
InferenceEngine::OutputsDataMap out_map = network_.getOutputsInfo();
auto iter = out_map.begin();
// The first output is skipped; the second entry of the output map is read
// as the score map and the third as the label map.
iter++;
std::string output_name_score = iter->first;
InferenceEngine::Blob::Ptr output_score =
infer_request.GetBlob(output_name_score);
InferenceEngine::MemoryBlob::CPtr moutput_score =
InferenceEngine::as<InferenceEngine::MemoryBlob>(output_score);
InferenceEngine::TensorDesc blob_score = moutput_score->getTensorDesc();
std::vector<size_t> output_score_shape = blob_score.getDims();
// Record the score tensor's shape and compute its element count.
int size = 1;
for (auto& i : output_score_shape) {
size *= static_cast<int>(i);
result->score_map.shape.push_back(static_cast<int>(i));
}
result->score_map.data.resize(size);
auto moutputHolder_score = moutput_score->rmap();
float* score_data = moutputHolder_score.as<float *>();
memcpy(result->score_map.data.data(), score_data, moutput_score->byteSize());
iter++;
std::string output_name_label = iter->first;
InferenceEngine::Blob::Ptr output_label =
infer_request.GetBlob(output_name_label);
InferenceEngine::MemoryBlob::CPtr moutput_label =
InferenceEngine::as<InferenceEngine::MemoryBlob>(output_label);
InferenceEngine::TensorDesc blob_label = moutput_label->getTensorDesc();
std::vector<size_t> output_label_shape = blob_label.getDims();
// Record the label tensor's shape and compute its element count.
size = 1;
for (auto& i : output_label_shape) {
size *= static_cast<int>(i);
result->label_map.shape.push_back(static_cast<int>(i));
}
result->label_map.data.resize(size);
auto moutputHolder_label = moutput_label->rmap();
int* label_data = moutputHolder_label.as<int *>();
// NOTE(review): the raw bytes are copied as-is; this assumes the element
// type of label_map.data has the same width as the blob's elements -- confirm.
memcpy(result->label_map.data.data(), label_data, moutput_label->byteSize());
// Narrow the labels to 8 bits so they can back a CV_8UC1 matrix.
std::vector<uint8_t> label_map(result->label_map.data.begin(),
result->label_map.data.end());
// The label tensor's dims 1 and 2 are used as rows/cols.
cv::Mat mask_label(result->label_map.shape[1],
result->label_map.shape[2],
CV_8UC1,
label_map.data());
// The score tensor's dims 2 and 3 are used as rows/cols.
cv::Mat mask_score(result->score_map.shape[2],
result->score_map.shape[3],
CV_32FC1,
result->score_map.data.data());
// Undo the recorded preprocessing steps, newest first. Each step consumes
// the matching entry of im_size_before_resize_.
int idx = 1;
int len_postprocess = inputs_.im_size_before_resize_.size();
for (std::vector<std::string>::reverse_iterator iter =
inputs_.reshape_order_.rbegin();
iter != inputs_.reshape_order_.rend();
++iter) {
if (*iter == "padding") {
// Padding is undone by cropping the top-left region.
auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
inputs_.im_size_before_resize_.pop_back();
// before_shape[0] is passed below as the Rect height and before_shape[1]
// as its width, despite the `_w` / `_h` names.
auto padding_w = before_shape[0];
auto padding_h = before_shape[1];
mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
} else if (*iter == "resize") {
auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
inputs_.im_size_before_resize_.pop_back();
// Same naming caveat as above: cv::Size takes (width, height).
auto resize_w = before_shape[0];
auto resize_h = before_shape[1];
// Labels use nearest-neighbour interpolation, scores use bilinear.
cv::resize(mask_label,
mask_label,
cv::Size(resize_h, resize_w),
0,
0,
cv::INTER_NEAREST);
cv::resize(mask_score,
mask_score,
cv::Size(resize_h, resize_w),
0,
0,
cv::INTER_LINEAR);
}
++idx;
}
// Store the restored maps and their 2-D shapes {rows, cols} in the result.
result->label_map.data.assign(mask_label.begin<uint8_t>(),
mask_label.end<uint8_t>());
result->label_map.shape = {mask_label.rows, mask_label.cols};
result->score_map.data.assign(mask_score.begin<float>(),
mask_score.end<float>());
result->score_map.shape = {mask_score.rows, mask_score.cols};
return true;
}
} // namespace PaddleX
......@@ -12,11 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/paddlex/transforms.h"
#include <math.h>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include "include/paddlex/transforms.h"
namespace PaddleX {
......@@ -26,7 +30,7 @@ std::map<std::string, int> interpolations = {{"LINEAR", cv::INTER_LINEAR},
{"CUBIC", cv::INTER_CUBIC},
{"LANCZOS4", cv::INTER_LANCZOS4}};
bool Normalize::Run(cv::Mat* im){
bool Normalize::Run(cv::Mat* im, ImageBlob* data) {
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
......@@ -40,19 +44,6 @@ bool Normalize::Run(cv::Mat* im){
return true;
}
// Crops the width_ x height_ window centered in `im`, in place.
// Returns false (after logging) when the image is smaller than the window.
bool CenterCrop::Run(cv::Mat* im) {
  const int rows = static_cast<int>(im->rows);
  const int cols = static_cast<int>(im->cols);
  if (rows < height_ || cols < width_) {
    std::cerr << "[CenterCrop] Image size less than crop size" << std::endl;
    return false;
  }
  const int left = static_cast<int>((cols - width_) / 2);
  const int top = static_cast<int>((rows - height_) / 2);
  *im = (*im)(cv::Rect(left, top, width_, height_));
  return true;
}
float ResizeByShort::GenerateScale(const cv::Mat& im) {
......@@ -70,17 +61,115 @@ float ResizeByShort::GenerateScale(const cv::Mat& im) {
return scale;
}
bool ResizeByShort::Run(cv::Mat* im) {
bool ResizeByShort::Run(cv::Mat* im, ImageBlob* data) {
data->im_size_before_resize_.push_back({im->rows, im->cols});
data->reshape_order_.push_back("resize");
float scale = GenerateScale(*im);
int width = static_cast<int>(scale * im->cols);
int height = static_cast<int>(scale * im->rows);
int width = static_cast<int>(round(scale * im->cols));
int height = static_cast<int>(round(scale * im->rows));
cv::resize(*im, *im, cv::Size(width, height), 0, 0, cv::INTER_LINEAR);
data->new_im_size_[0] = im->rows;
data->new_im_size_[1] = im->cols;
data->scale = scale;
return true;
}
void Transforms::Init(const YAML::Node& transforms_node, bool to_rgb) {
// Crops the width_ x height_ window centered in `im`, in place, and records
// the resulting size in data->new_im_size_ as {rows, cols}.
// Returns false (after logging) when the image is smaller than the window.
bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
  const int rows = static_cast<int>(im->rows);
  const int cols = static_cast<int>(im->cols);
  if (rows < height_ || cols < width_) {
    std::cerr << "[CenterCrop] Image size less than crop size" << std::endl;
    return false;
  }
  const int left = static_cast<int>((cols - width_) / 2);
  const int top = static_cast<int>((rows - height_) / 2);
  *im = (*im)(cv::Rect(left, top, width_, height_));
  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}
// Pads `im` on its bottom and right edges with the constant colour
// im_value_, in place.
//
// The target size is either the fixed width_ x height_ (when both are > 1)
// or, failing that, the next multiple of coarsest_stride_ in each dimension
// (when coarsest_stride_ >= 1). If neither applies no padding is added.
//
// The size before padding and the tag "padding" are appended to `data` so
// the step can be undone later; the new size is written to
// data->new_im_size_ as {rows, cols}. Returns false (after logging) when
// the image is already larger than the target size.
bool Padding::Run(cv::Mat* im, ImageBlob* data) {
  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("padding");

  int padding_w = 0;
  int padding_h = 0;
  // Logical AND: the original used bitwise `&` on the two comparisons.
  if (width_ > 1 && height_ > 1) {
    padding_w = width_ - im->cols;
    padding_h = height_ - im->rows;
  } else if (coarsest_stride_ >= 1) {
    int h = im->rows;
    int w = im->cols;
    padding_h =
        ceil(h * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows;
    padding_w =
        ceil(w * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols;
  }

  if (padding_h < 0 || padding_w < 0) {
    std::cerr << "[Padding] Computed padding_h=" << padding_h
              << ", padding_w=" << padding_w
              << ", but they should be greater than 0." << std::endl;
    return false;
  }

  cv::Scalar value = cv::Scalar(im_value_[0], im_value_[1], im_value_[2]);
  cv::copyMakeBorder(
      *im, *im, 0, padding_h, 0, padding_w, cv::BORDER_CONSTANT, value);
  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}
// Scales `im` in place so that its longer side becomes long_size_, using
// nearest-neighbour interpolation.
//
// The size before resizing and the tag "resize" are appended to `data`;
// the new size {rows, cols} and the scale factor are stored in `data` too.
// Returns false (after logging) when long_size_ is not positive.
bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) {
  if (long_size_ <= 0) {
    std::cerr << "[ResizeByLong] long_size should be greater than 0"
              << std::endl;
    return false;
  }

  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  const int longest_side = std::max(im->cols, im->rows);
  const float scale =
      static_cast<float>(long_size_) / static_cast<float>(longest_side);
  // An empty cv::Size makes OpenCV derive the output size from the factors.
  cv::resize(*im, *im, cv::Size(), scale, scale, cv::INTER_NEAREST);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  data->scale = scale;
  return true;
}
// Resizes `im` in place to width_ x height_ using the interpolation method
// named by interp_.
//
// Returns false (after logging) when either dimension is not positive or
// when interp_ is not a key of the `interpolations` table. On success the
// size before resizing and the tag "resize" are appended to `data`, and the
// new size {rows, cols} is stored in data->new_im_size_.
bool Resize::Run(cv::Mat* im, ImageBlob* data) {
  if (width_ <= 0 || height_ <= 0) {
    std::cerr << "[Resize] width and height should be greater than 0"
              << std::endl;
    return false;
  }

  const auto interp_entry = interpolations.find(interp_);
  if (interp_entry == interpolations.end()) {
    std::cerr << "[Resize] Invalid interpolation method: '" << interp_ << "'"
              << std::endl;
    return false;
  }

  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  cv::resize(
      *im, *im, cv::Size(width_, height_), 0, 0, interp_entry->second);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}
void Transforms::Init(
const YAML::Node& transforms_node, std::string type, bool to_rgb) {
transforms_.clear();
to_rgb_ = to_rgb;
type_ = type;
for (const auto& item : transforms_node) {
std::string name = item.begin()->first.as<std::string>();
std::cout << "trans name: " << name << std::endl;
......@@ -94,10 +183,16 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
const std::string& transform_name) {
if (transform_name == "Normalize") {
return std::make_shared<Normalize>();
} else if (transform_name == "CenterCrop") {
return std::make_shared<CenterCrop>();
} else if (transform_name == "ResizeByShort") {
return std::make_shared<ResizeByShort>();
} else if (transform_name == "CenterCrop") {
return std::make_shared<CenterCrop>();
} else if (transform_name == "Resize") {
return std::make_shared<Resize>();
} else if (transform_name == "Padding") {
return std::make_shared<Padding>();
} else if (transform_name == "ResizeByLong") {
return std::make_shared<ResizeByLong>();
} else {
std::cerr << "There's unexpected transform(name='" << transform_name
<< "')." << std::endl;
......@@ -105,27 +200,38 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
}
}
bool Transforms::Run(cv::Mat* im, Blob::Ptr blob) {
// 按照transforms中预处理算子顺序处理图像
bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
// preprocess by order
if (to_rgb_) {
cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
}
(*im).convertTo(*im, CV_32FC3);
if (type_ == "detector") {
InferenceEngine::LockedMemory<void> input2Mapped =
InferenceEngine::as<InferenceEngine::MemoryBlob>(
data->ori_im_size_)->wmap();
float *p = input2Mapped.as<float*>();
p[0] = im->rows;
p[1] = im->cols;
}
data->new_im_size_[0] = im->rows;
data->new_im_size_[1] = im->cols;
for (int i = 0; i < transforms_.size(); ++i) {
if (!transforms_[i]->Run(im)) {
if (!transforms_[i]->Run(im, data)) {
std::cerr << "Apply transforms to image failed!" << std::endl;
return false;
}
}
// 将图像由NHWC转为NCHW格式
// 同时转为连续的内存块存储到Blob
SizeVector blobSize = blob->getTensorDesc().getDims();
// image format NHWC to NCHW
// img data save to ImageBlob
InferenceEngine::SizeVector blobSize = data->blob->getTensorDesc().getDims();
const size_t width = blobSize[3];
const size_t height = blobSize[2];
const size_t channels = blobSize[1];
MemoryBlob::Ptr mblob = InferenceEngine::as<MemoryBlob>(blob);
InferenceEngine::MemoryBlob::Ptr mblob =
InferenceEngine::as<InferenceEngine::MemoryBlob>(data->blob);
auto mblobHolder = mblob->wmap();
float *blob_data = mblobHolder.as<float *>();
for (size_t c = 0; c < channels; c++) {
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/paddlex/visualize.h"
namespace PaddleX {
// Builds a flat colour table with three entries per class index.
//
// For class i the bits of i are consumed three at a time: within each group,
// bit 0 feeds the first component, bit 1 the second and bit 2 the third.
// The first group sets bit 7 of each component, the next group bit 6, and
// so on. Class 0 therefore maps to (0, 0, 0).
std::vector<int> GenerateColorMap(int num_class) {
  std::vector<int> colormap(3 * num_class, 0);
  for (int cls = 0; cls < num_class; ++cls) {
    int* rgb = &colormap[cls * 3];
    int shift = 7;
    for (int bits = cls; bits != 0; bits >>= 3, --shift) {
      rgb[0] |= (((bits >> 0) & 1) << shift);
      rgb[1] |= (((bits >> 1) & 1) << shift);
      rgb[2] |= (((bits >> 2) & 1) << shift);
    }
  }
  return colormap;
}
// Draws detection results on a copy of `img` and returns the copy.
//
// For every box whose score is at least `threshold`, a rectangle and a
// filled title bar with "<category><score in percent>%" are drawn in the
// colour taken from `colormap` at index 3 * category_id. If the box carries
// mask data, a mask overlay is blended in as well. `labels` is not used by
// this function.
cv::Mat Visualize(const cv::Mat& img,
const DetResult& result,
const std::map<int, std::string>& labels,
const std::vector<int>& colormap,
float threshold) {
cv::Mat vis_img = img.clone();
// Copy of the boxes: the mask buffer is wrapped in a cv::Mat further down,
// which needs a non-const data pointer.
auto boxes = result.boxes;
for (int i = 0; i < boxes.size(); ++i) {
if (boxes[i].score < threshold) {
continue;
}
// coordinate holds {x, y, width, height}.
cv::Rect roi = cv::Rect(boxes[i].coordinate[0],
boxes[i].coordinate[1],
boxes[i].coordinate[2],
boxes[i].coordinate[3]);
// draw box and title
std::string text = boxes[i].category;
int c1 = colormap[3 * boxes[i].category_id + 0];
int c2 = colormap[3 * boxes[i].category_id + 1];
int c3 = colormap[3 * boxes[i].category_id + 2];
cv::Scalar roi_color = cv::Scalar(c1, c2, c3);
text += std::to_string(static_cast<int>(boxes[i].score * 100)) + "%";
int font_face = cv::FONT_HERSHEY_SIMPLEX;
double font_scale = 0.5f;
// NOTE(review): OpenCV's thickness parameters are integers, so 0.5 is
// presumably truncated to 0 when passed below -- confirm this is intended.
float thickness = 0.5;
cv::Size text_size =
cv::getTextSize(text, font_face, font_scale, thickness, nullptr);
cv::Point origin;
origin.x = roi.x;
origin.y = roi.y;
// background
// The title bar sits directly above the box's top-left corner.
cv::Rect text_back = cv::Rect(boxes[i].coordinate[0],
boxes[i].coordinate[1] - text_size.height,
text_size.width,
text_size.height);
// draw
cv::rectangle(vis_img, roi, roi_color, 2);
// A thickness of -1 draws a filled rectangle.
cv::rectangle(vis_img, text_back, roi_color, -1);
cv::putText(vis_img,
text,
origin,
font_face,
font_scale,
cv::Scalar(255, 255, 255),
thickness);
// mask
if (boxes[i].mask.data.size() == 0) {
continue;
}
// Wrap the mask buffer as a mask_resolution x mask_resolution float matrix.
cv::Mat bin_mask(result.mask_resolution,
result.mask_resolution,
CV_32FC1,
boxes[i].mask.data.data());
cv::resize(bin_mask,
bin_mask,
cv::Size(boxes[i].mask.shape[0], boxes[i].mask.shape[1]));
// Binarize: values above 0.5 become 1, everything else 0.
cv::threshold(bin_mask, bin_mask, 0.5, 1, cv::THRESH_BINARY);
cv::Mat full_mask = cv::Mat::zeros(vis_img.size(), CV_8UC1);
// NOTE(review): bin_mask is CV_32FC1 while full_mask is CV_8UC1; verify
// that copyTo really writes into full_mask's region given the type mismatch.
bin_mask.copyTo(full_mask(roi));
// Build a 3-channel overlay by scaling the mask with the class colour.
cv::Mat mask_ch[3];
mask_ch[0] = full_mask * c1;
mask_ch[1] = full_mask * c2;
mask_ch[2] = full_mask * c3;
cv::Mat mask;
cv::merge(mask_ch, 3, mask);
// Blend the overlay onto the image with weight 0.5.
cv::addWeighted(vis_img, 1, mask, 0.5, 0, vis_img);
}
return vis_img;
}
// Renders a segmentation label map as a colour image.
//
// The returned CV_8UC3 matrix has the label map's shape
// (label_map.shape[0] rows, label_map.shape[1] columns). Each pixel inside
// the region shared with `img` receives the colour stored in `colormap` at
// index 3 * category_id; the remaining pixels stay black. `labels` is not
// used by this function.
cv::Mat Visualize(const cv::Mat& img,
                  const SegResult& result,
                  const std::map<int, std::string>& labels,
                  const std::vector<int>& colormap) {
  // Narrow the labels to 8 bits so they can back a CV_8UC1 matrix.
  std::vector<uint8_t> label_map(result.label_map.data.begin(),
                                 result.label_map.data.end());
  cv::Mat mask(result.label_map.shape[0],
               result.label_map.shape[1],
               CV_8UC1,
               label_map.data());
  cv::Mat color_mask = cv::Mat::zeros(
      result.label_map.shape[0], result.label_map.shape[1], CV_8UC3);

  // The original looped over img.rows x img.cols while indexing `mask` and
  // `color_mask`, which are sized from the label map; clamp to the smaller
  // extent so a size mismatch cannot index out of bounds.
  const int rows = img.rows < mask.rows ? img.rows : mask.rows;
  const int cols = img.cols < mask.cols ? img.cols : mask.cols;
  for (int i = 0; i < rows; i++) {
    for (int j = 0; j < cols; j++) {
      int category_id = static_cast<int>(mask.at<uchar>(i, j));
      color_mask.at<cv::Vec3b>(i, j)[0] = colormap[3 * category_id + 0];
      color_mask.at<cv::Vec3b>(i, j)[1] = colormap[3 * category_id + 1];
      color_mask.at<cv::Vec3b>(i, j)[2] = colormap[3 * category_id + 2];
    }
  }
  return color_mask;
}
// Returns save_dir + OS_PATH_SEP + <file name of file_path>.
//
// If `save_dir` is not accessible, an attempt is made to create it; a
// failure is only logged (on non-Windows builds) and the path is returned
// regardless.
std::string generate_save_path(const std::string& save_dir,
                               const std::string& file_path) {
  const bool dir_missing = access(save_dir.c_str(), 0) < 0;
  if (dir_missing) {
#ifdef _WIN32
    mkdir(save_dir.c_str());
#else
    if (mkdir(save_dir.c_str(), S_IRWXU) < 0) {
      std::cerr << "Fail to create " << save_dir << "directory." << std::endl;
    }
#endif
  }
  // Stored as int on purpose: a "not found" result becomes -1, so
  // sep_pos + 1 is 0 and the whole of file_path is used as the file name.
  const int sep_pos = file_path.find_last_of(OS_PATH_SEP);
  return save_dir + OS_PATH_SEP + file_path.substr(sep_pos + 1);
}
} // namespace PaddleX
# Top-level build configuration for the PaddleX C++ demo executables.
cmake_minimum_required(VERSION 3.0)
project(PaddleX CXX C)
# Selects which library suffix is used for the inference library below.
# NOTE(review): the help text says "default use static" but the default is OFF.
option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." OFF)
SET(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH})
# Locations of the external dependencies, supplied with -D<NAME>=<path>.
SET(LITE_DIR "" CACHE PATH "Location of libraries")
SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
SET(NGRAPH_LIB "" CACHE PATH "Location of libraries")
# yaml-cpp is pulled in through the helper script under cmake/.
include(cmake/yaml-cpp.cmake)
include_directories("${CMAKE_SOURCE_DIR}/")
link_directories("${CMAKE_CURRENT_BINARY_DIR}")
# Header and library locations of the yaml-cpp external project.
include_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/src/ext-yaml-cpp/include")
link_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/lib")
# Replaces "/MD" with "/MT" in every C++ flag variable that contains it.
macro(safe_set_static_flag)
    foreach(flag_var
            CMAKE_CXX_FLAGS
            CMAKE_CXX_FLAGS_DEBUG
            CMAKE_CXX_FLAGS_RELEASE
            CMAKE_CXX_FLAGS_MINSIZEREL
            CMAKE_CXX_FLAGS_RELWITHDEBINFO)
        if(${flag_var} MATCHES "/MD")
            string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
        endif()
    endforeach()
endmacro()
# Stop with a usage hint when a required dependency location is missing.
# The variable references are quoted: an unquoted reference to an empty
# variable expands to nothing and leaves STREQUAL without a left operand,
# so the intended message was never reached.
if (NOT DEFINED LITE_DIR OR "${LITE_DIR}" STREQUAL "")
    message(FATAL_ERROR "please set LITE_DIR with -DLITE_DIR=/path/influence_engine")
endif()
if (NOT DEFINED OPENCV_DIR OR "${OPENCV_DIR}" STREQUAL "")
    message(FATAL_ERROR "please set OPENCV_DIR with -DOPENCV_DIR=/path/opencv")
endif()
if (NOT DEFINED GFLAGS_DIR OR "${GFLAGS_DIR}" STREQUAL "")
    message(FATAL_ERROR "please set GFLAGS_DIR with -DGFLAGS_DIR=/path/gflags")
endif()
# Header and library search paths of the inference library and gflags.
link_directories("${LITE_DIR}/lib")
include_directories("${LITE_DIR}/include")
link_directories("${GFLAGS_DIR}/lib")
include_directories("${GFLAGS_DIR}/include")
# Locate OpenCV only inside OPENCV_DIR; the package config lives in a
# different sub-directory on Windows.
if (WIN32)
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH)
unset(OpenCV_DIR CACHE)
else ()
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/cmake NO_DEFAULT_PATH)
endif ()
include_directories(${OpenCV_INCLUDE_DIRS})
# Platform-specific compiler flags.
if (WIN32)
    add_definitions("/DGOOGLE_GLOG_DLL_DECL=")
    set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd")
    set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT")
    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd")
    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT")
    if (WITH_STATIC_LIB)
        safe_set_static_flag()
        add_definitions(-DSTATIC_LIB)
    endif()
else()
    # -O2 (capital O) is the optimisation level; the previous lowercase
    # "-o2" is the output-file option and enabled no optimisation.
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=hard -mfpu=neon-vfpv4 -g -O2 -fopenmp -std=c++11")
    set(CMAKE_STATIC_LIBRARY_PREFIX "")
endif()
# Inference library: only the file suffix differs between the two branches.
# NOTE(review): the static branch still names "libpaddle_full_api_shared" -- confirm.
if(WITH_STATIC_LIB)
set(DEPS ${LITE_DIR}/lib/libpaddle_full_api_shared${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(DEPS ${LITE_DIR}/lib/libpaddle_full_api_shared${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
# Third-party libraries; names differ between Windows and other platforms.
if (NOT WIN32)
set(DEPS ${DEPS}
glog gflags z yaml-cpp
)
else()
set(DEPS ${DEPS}
glog gflags_static libprotobuf zlibstatic xxhash libyaml-cppmt)
set(DEPS ${DEPS} libcmt shlwapi)
endif(NOT WIN32)
# System libraries needed on non-Windows platforms.
if (NOT WIN32)
set(EXTERNAL_LIB "-ldl -lrt -lgomp -lz -lm -lpthread")
set(DEPS ${DEPS} ${EXTERNAL_LIB})
endif()
set(DEPS ${DEPS} ${OpenCV_LIBS})
# Demo executables. Each one depends on the yaml-cpp external project so
# that it is built first.
add_executable(classifier demo/classifier.cpp src/transforms.cpp src/paddlex.cpp)
ADD_DEPENDENCIES(classifier ext-yaml-cpp)
target_link_libraries(classifier ${DEPS})
add_executable(segmenter demo/segmenter.cpp src/transforms.cpp src/paddlex.cpp src/visualize.cpp)
ADD_DEPENDENCIES(segmenter ext-yaml-cpp)
target_link_libraries(segmenter ${DEPS})
add_executable(detector demo/detector.cpp src/transforms.cpp src/paddlex.cpp src/visualize.cpp)
ADD_DEPENDENCIES(detector ext-yaml-cpp)
target_link_libraries(detector ${DEPS})
# Downloads and builds yaml-cpp as an external project named ext-yaml-cpp.
include(ExternalProject)
message("${CMAKE_BUILD_TYPE}")
ExternalProject_Add(
ext-yaml-cpp
# The archive is verified against the MD5 checksum below.
URL https://bj.bcebos.com/paddlex/deploy/deps/yaml-cpp.zip
URL_MD5 9542d6de397d1fbd649ed468cb5850e6
# Build only the static library (tests, tools, contrib and install are off)
# and forward the parent project's build type and C++ flags.
CMAKE_ARGS
-DYAML_CPP_BUILD_TESTS=OFF
-DYAML_CPP_BUILD_TOOLS=OFF
-DYAML_CPP_INSTALL=OFF
-DYAML_CPP_BUILD_CONTRIB=OFF
-DMSVC_SHARED_RT=OFF
-DBUILD_SHARED_LIBS=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=${CMAKE_BINARY_DIR}/ext/yaml-cpp/lib
-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=${CMAKE_BINARY_DIR}/ext/yaml-cpp/lib
PREFIX "${CMAKE_BINARY_DIR}/ext/yaml-cpp"
# Disable install step
INSTALL_COMMAND ""
LOG_DOWNLOAD ON
LOG_BUILD 1
)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "include/paddlex/paddlex.h"
// Command-line flags of the classifier demo.
// --model_dir and --cfg_file are mandatory; at least one of --image and
// --image_list must be given (all three are checked in main()).
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(cfg_file, "", "Path of PaddleX model yml file");
DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_int32(thread_num, 1, "num of thread to infer");
int main(int argc, char** argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir == "") {
std::cerr << "--model_dir need to be defined" << std::endl;
return -1;
}
if (FLAGS_cfg_file == "") {
std::cerr << "--cfg_flie need to be defined" << std::endl;
return -1;
}
if (FLAGS_image == "" & FLAGS_image_list == "") {
std::cerr << "--image or --image_list need to be defined" << std::endl;
return -1;
}
// load model
PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_cfg_file, FLAGS_thread_num);
std::cout << "init is done" << std::endl;
// predict
if (FLAGS_image_list != "") {
std::ifstream inf(FLAGS_image_list);
if (!inf) {
std::cerr << "Fail to open file " << FLAGS_image_list << std::endl;
return -1;
}
std::string image_path;
while (getline(inf, image_path)) {
PaddleX::ClsResult result;
cv::Mat im = cv::imread(image_path, 1);
model.predict(im, &result);
std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl;
}
} else {
PaddleX::ClsResult result;
cv::Mat im = cv::imread(FLAGS_image, 1);
model.predict(im, &result);
std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl;
}
return 0;
}
此差异已折叠。
此差异已折叠。
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "yaml-cpp/yaml.h"
// Separator placed between path components: backslash when _WIN32 is
// defined, forward slash otherwise.
#ifndef _WIN32
#define OS_PATH_SEP "/"
#else
#define OS_PATH_SEP "\\"
#endif
namespace PaddleX {
// Inference model configuration parser
class ConfigPaser {
public:
ConfigPaser() {}
~ConfigPaser() {}
bool load_config(const std::string& model_dir,
const std::string& cfg = "model.yml") {
// Load as a YAML::Node
YAML::Node config;
config = YAML::LoadFile(model_dir + OS_PATH_SEP + cfg);
if (config["Transforms"].IsDefined()) {
YAML::Node transforms_ = config["Transforms"];
} else {
std::cerr << "There's no field 'Transforms' in model.yml" << std::endl;
return false;
}
return true;
}
YAML::Node Transforms_;
};
} // namespace PaddleX
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -6,6 +6,8 @@ OpenVINO部署
:maxdepth: 2
:caption: 文档目录:
introduction.md
windows.md
linux.md
intel_movidius.md
python.md
export_openvino_model.md
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册