提交 646c3d16 编写于 作者: S syyxsxx

add raspberry support

上级 36803267
文件已添加
文件已添加
......@@ -105,10 +105,7 @@ if (NOT WIN32)
endif()
set(DEPS ${DEPS} ${OpenCV_LIBS})
add_executable(classifier demo/classifier.cpp src/transforms.cpp src/paddlex.cpp)
add_executable(classifier src/classifier.cpp src/transforms.cpp src/paddlex.cpp)
ADD_DEPENDENCIES(classifier ext-yaml-cpp)
target_link_libraries(classifier ${DEPS})
add_executable(segmenter demo/segmenter.cpp src/transforms.cpp src/paddlex.cpp src/visualize.cpp)
ADD_DEPENDENCIES(segmenter ext-yaml-cpp)
target_link_libraries(segmenter ${DEPS})
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
......@@ -37,9 +37,9 @@ namespace PaddleX {
class Model {
public:
void Init(const std::string& model_dir,
const std::string& cfg_file,
const std::string& cfg_dir,
std::string device) {
create_predictor(model_dir, cfg_file, device);
create_predictor(model_dir, cfg_dir, device);
}
void create_predictor(const std::string& model_dir,
......@@ -48,18 +48,15 @@ class Model {
bool load_config(const std::string& model_dir);
bool preprocess(cv::Mat* input_im, ImageBlob* inputs);
bool preprocess(cv::Mat* input_im);
bool predict(const cv::Mat& im, ClsResult* result);
bool predict(const cv::Mat& im, SegResult* result);
std::string type;
std::string name;
std::map<int, std::string> labels;
std::vector<std::string> labels;
Transforms transforms_;
ImageBlob inputs_;
Blob::Ptr inputs_;
Blob::Ptr output_;
CNNNetwork network_;
ExecutableNetwork executable_network_;
......
......@@ -61,7 +61,7 @@ class DetResult : public BaseResult {
class SegResult : public BaseResult {
public:
Mask<int> label_map;
Mask<int64_t> label_map;
Mask<float> score_map;
void clear() {
label_map.clear();
......
......@@ -31,38 +31,11 @@ using namespace InferenceEngine;
namespace PaddleX {
/*
 * @brief
 * Holds what the preprocessing pipeline produces for a single image: size
 * bookkeeping, the history of reshaping steps, and the blob that receives
 * the preprocessed pixels.
 * */
class ImageBlob {
 public:
  // Height ([0]) and width ([1]) of the image as it entered the pipeline
  std::vector<int> ori_im_size_ = std::vector<int>(2, 0);
  // Height ([0]) and width ([1]) after the most recent transform
  std::vector<int> new_im_size_ = std::vector<int>(2, 0);
  // {rows, cols} recorded by each reshaping transform before it ran
  std::vector<std::vector<int>> im_size_before_resize_;
  // Kind of each reshaping step ("resize" / "padding"), in the order applied
  std::vector<std::string> reshape_order_;
  // Scale factor written by the resizing transforms
  float scale = 1.0f;
  // Blob that the transforms pipeline fills with the preprocessed image
  Blob::Ptr blob;

  // Drops the per-image reshape history; sizes, scale and blob are untouched.
  void clear() {
    reshape_order_.clear();
    im_size_before_resize_.clear();
  }
};
// Abstraction of preprocessing opration class
class Transform {
public:
virtual void Init(const YAML::Node& item) = 0;
virtual bool Run(cv::Mat* im, ImageBlob* data) = 0;
virtual bool Run(cv::Mat* im) = 0;
};
class Normalize : public Transform {
......@@ -72,7 +45,7 @@ class Normalize : public Transform {
std_ = item["std"].as<std::vector<float>>();
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
virtual bool Run(cv::Mat* im);
private:
std::vector<float> mean_;
......@@ -89,7 +62,7 @@ class ResizeByShort : public Transform {
max_size_ = -1;
}
};
virtual bool Run(cv::Mat* im, ImageBlob* data);
virtual bool Run(cv::Mat* im);
private:
float GenerateScale(const cv::Mat& im);
......@@ -97,55 +70,6 @@ class ResizeByShort : public Transform {
int max_size_;
};
/*
 * @brief
 * Transform that rescales an image so that its longer side becomes
 * `long_size`; the shorter side is scaled by the same factor.
 * */
class ResizeByLong : public Transform {
 public:
  // Reads the `long_size` entry of the YAML item describing this transform.
  virtual void Init(const YAML::Node& item) {
    const YAML::Node long_size_node = item["long_size"];
    long_size_ = long_size_node.as<int>();
  }

  virtual bool Run(cv::Mat* im, ImageBlob* data);

 private:
  // Target length of the longer image side; set by Init().
  int long_size_;
};
/*
 * @brief
 * This class executes the resize operation on an image matrix. It resizes
 * width and height to the configured target size.
 * */
class Resize : public Transform {
 public:
  // Reads `interp` (optional) and `target_size` from the YAML item.
  // `target_size` is either a scalar (used for both sides) or a sequence
  // ordered as [width, height]. Terminates the process when no positive
  // size could be read.
  virtual void Init(const YAML::Node& item) {
    if (item["interp"].IsDefined()) {
      interp_ = item["interp"].as<std::string>();
    }
    const YAML::Node target = item["target_size"];
    if (target.IsScalar()) {
      const int side = target.as<int>();
      height_ = side;
      width_ = side;
    } else if (target.IsSequence()) {
      const std::vector<int> target_size = target.as<std::vector<int>>();
      // A sequence with fewer than two entries leaves the sizes at 0 so the
      // check below reports the problem instead of indexing out of range.
      if (target_size.size() >= 2) {
        width_ = target_size[0];
        height_ = target_size[1];
      }
    }
    if (height_ <= 0 || width_ <= 0) {
      std::cerr << "[Resize] target_size should greater than 0" << std::endl;
      exit(-1);
    }
  }
  virtual bool Run(cv::Mat* im, ImageBlob* data);

 private:
  // Zero-initialised so the validity check in Init() never reads an
  // indeterminate value when `target_size` is missing or malformed.
  int height_ = 0;
  int width_ = 0;
  // Interpolation name looked up by Run(); stays empty when the config has
  // no `interp` entry.
  std::string interp_;
};
class CenterCrop : public Transform {
public:
......@@ -159,53 +83,18 @@ class CenterCrop : public Transform {
height_ = crop_size[1];
}
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
virtual bool Run(cv::Mat* im);
private:
int height_;
int width_;
};
/*
 * @brief
 * Transform that pads an image matrix by adding a border on its edges.
 * The amount is derived either from an explicit target size or from a
 * stride the image sides should be rounded up to.
 * */
class Padding : public Transform {
 public:
  // Reads the optional `coarsest_stride` and `target_size` entries.
  // Terminates the process when `coarsest_stride` is present but below 1.
  virtual void Init(const YAML::Node& item) {
    const YAML::Node stride_node = item["coarsest_stride"];
    if (stride_node.IsDefined()) {
      coarsest_stride_ = stride_node.as<int>();
      if (coarsest_stride_ < 1) {
        std::cerr << "[Padding] coarest_stride should greater than 0"
                  << std::endl;
        exit(-1);
      }
    }

    const YAML::Node target_node = item["target_size"];
    if (!target_node.IsDefined()) {
      return;
    }
    if (target_node.IsScalar()) {
      // A scalar applies to both sides.
      width_ = target_node.as<int>();
      height_ = target_node.as<int>();
    } else if (target_node.IsSequence()) {
      // A sequence is ordered as [width, height].
      width_ = target_node.as<std::vector<int>>()[0];
      height_ = target_node.as<std::vector<int>>()[1];
    }
  }

  virtual bool Run(cv::Mat* im, ImageBlob* data);

 private:
  int coarsest_stride_ = -1;
  int width_ = 0;
  int height_ = 0;
};
class Transforms {
public:
void Init(const YAML::Node& node, bool to_rgb = true);
std::shared_ptr<Transform> CreateTransform(const std::string& name);
bool Run(cv::Mat* im, ImageBlob* data);
bool Run(cv::Mat* im, Blob::Ptr blob);
private:
std::vector<std::shared_ptr<Transform>> transforms_;
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import os.path as osp
import cv2
import numpy as np
import yaml
from six import text_type as _text_type
from openvino.inference_engine import IECore
from utils import logging
class Predictor:
    """Runs a PaddleX model converted to OpenVINO IR through the Inference Engine.

    The network topology is read from ``model_xml`` (weights from the ``.bin``
    file with the same stem) and the preprocessing pipeline is rebuilt from
    the ``Transforms`` section of ``model_yaml``.
    """

    def __init__(self, model_xml, model_yaml, device="CPU"):
        """Load the model description and create the executable network.

        Args:
            model_xml: Path of the OpenVINO IR ``.xml`` file.
            model_yaml: Path of the PaddleX model yaml file.
            device: Inference Engine device name passed to ``load_network``.
        """
        self.device = device
        # NOTE(review): whether logging.error() aborts or only reports depends
        # on utils.logging, which is not visible here - confirm.
        if not osp.exists(model_xml):
            logging.error("model xml file is not exists in {}".format(
                model_xml))
        self.model_xml = model_xml
        self.model_bin = osp.splitext(model_xml)[0] + ".bin"
        if not osp.exists(model_yaml):
            logging.error("model yaml file is not exists in {}".format(
                model_yaml))
        with open(model_yaml) as f:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects;
            # only load model yaml files from a trusted source.
            self.info = yaml.load(f.read(), Loader=yaml.Loader)
        self.model_type = self.info['_Attributes']['model_type']
        self.model_name = self.info['Model']
        self.num_classes = self.info['_Attributes']['num_classes']
        self.labels = self.info['_Attributes']['labels']
        if self.info['Model'] == 'MaskRCNN':
            if self.info['_init_params']['with_fpn']:
                self.mask_head_resolution = 28
            else:
                self.mask_head_resolution = 14
        # Anything other than 'RGB' keeps the channel order as loaded.
        transforms_mode = self.info.get('TransformsMode', 'RGB')
        to_rgb = transforms_mode == 'RGB'
        self.transforms = self.build_transforms(self.info['Transforms'],
                                                to_rgb)
        self.predictor, self.net = self.create_predictor()

    def create_predictor(self):
        """Read the IR files and load them onto ``self.device``.

        Returns:
            A tuple ``(exec_net, net)``: the executable network and the
            network object it was loaded from.
        """
        # initialization for specified device
        logging.info("Creating Inference Engine")
        ie = IECore()
        logging.info("Loading network files:\n\t{}\n\t{}".format(
            self.model_xml, self.model_bin))
        net = ie.read_network(model=self.model_xml, weights=self.model_bin)
        net.batch_size = 1
        exec_net = ie.load_network(network=net, device_name=self.device)
        return exec_net, net

    def build_transforms(self, transforms_info, to_rgb=True):
        """Rebuild the evaluation transform pipeline described in the yaml.

        Args:
            transforms_info: List of single-key dicts mapping an operator
                name to its keyword arguments.
            to_rgb: Stored on the composed pipeline when it has a ``to_rgb``
                attribute.

        Returns:
            The composed transforms with the model-specific Arrange operator
            placed last.

        Raises:
            Exception: If the model type is unknown or an operator name does
                not exist in the selected transforms module.
        """
        if self.model_type == "classifier":
            import transforms.cls_transforms as transforms
        elif self.model_type == "detector":
            import transforms.det_transforms as transforms
        elif self.model_type == "segmenter":
            import transforms.seg_transforms as transforms
        else:
            # Without this branch `transforms` would be unbound below.
            raise Exception("Unrecognized model type: {}".format(
                self.model_type))
        op_list = list()
        for op_info in transforms_info:
            op_name = list(op_info.keys())[0]
            op_attr = op_info[op_name]
            if not hasattr(transforms, op_name):
                raise Exception(
                    "There's no operator named '{}' in transforms of {}".
                    format(op_name, self.model_type))
            op_list.append(getattr(transforms, op_name)(**op_attr))
        eval_transforms = transforms.Compose(op_list)
        if hasattr(eval_transforms, 'to_rgb'):
            eval_transforms.to_rgb = to_rgb
        self.arrange_transforms(eval_transforms)
        return eval_transforms

    def arrange_transforms(self, eval_transforms):
        """Make the last transform the test-mode Arrange operator of the model.

        An existing trailing ``Arrange*`` operator is replaced; otherwise the
        operator is appended.

        Raises:
            Exception: If ``self.model_type`` is not a known model type.
        """
        if self.model_type == 'classifier':
            import transforms.cls_transforms as transforms
            arrange_transform = transforms.ArrangeClassifier
        elif self.model_type == 'segmenter':
            # ArrangeSegmenter is looked up in the segmentation transforms.
            import transforms.seg_transforms as transforms
            arrange_transform = transforms.ArrangeSegmenter
        elif self.model_type == 'detector':
            # Detector Arrange operators are named after the model.
            import transforms.det_transforms as transforms
            arrange_name = 'Arrange{}'.format(self.model_name)
            arrange_transform = getattr(transforms, arrange_name)
        else:
            raise Exception("Unrecognized model type: {}".format(
                self.model_type))
        if type(eval_transforms.transforms[-1]).__name__.startswith('Arrange'):
            eval_transforms.transforms[-1] = arrange_transform(mode='test')
        else:
            eval_transforms.transforms.append(arrange_transform(mode='test'))

    def raw_predict(self, images):
        """Run one synchronous inference and return the first output blob.

        Args:
            images: Preprocessed input fed to the network's first input.

        Returns:
            The value of the network's first output.
        """
        input_blob = next(iter(self.net.inputs))
        out_blob = next(iter(self.net.outputs))
        # Start sync inference
        logging.info("Starting inference in synchronous mode")
        res = self.predictor.infer(inputs={input_blob: images})
        # Processing output blob
        logging.info("Processing output blob")
        res = res[out_blob]
        print("res: ", res)
        return res

    def preprocess(self, image):
        """Apply the transforms to ``image`` and add a leading batch axis.

        Only classifier models are handled. Detector and segmenter models
        would also need their extra transform outputs (image shape / resize
        info) passed on, which is not implemented here.

        Raises:
            NotImplementedError: If the model is not a classifier.
        """
        if self.model_type != "classifier":
            raise NotImplementedError(
                "Preprocessing for model type '{}' is not supported".format(
                    self.model_type))
        im, = self.transforms(image)
        return np.expand_dims(im, axis=0).copy()

    def predict(self, image, topk=1, threshold=0.5):
        """Preprocess ``image``, run inference and return the raw output.

        ``topk`` and ``threshold`` are accepted for interface compatibility
        but are not used yet.
        """
        preprocessed_input = self.preprocess(image)
        model_pred = self.raw_predict(preprocessed_input)
        return model_pred
文件模式从 100755 更改为 100644
# openvino预编译库的路径
OPENVINO_DIR=$INTEL_OPENVINO_DIR/inference_engine
OPENVINO_DIR=/path/to/inference_engine/
# gflags预编译库的路径
GFLAGS_DIR=/wangsiyuan06/gflags/build
GFLAGS_DIR=/path/to/gflags
# ngraph lib的路径,编译openvino时通常会生成
NGRAPH_LIB=$INTEL_OPENVINO_DIR/deployment_tools/ngraph/lib
NGRAPH_LIB=/path/to/ngraph/lib/
# opencv预编译库的路径, 如果使用自带预编译版本可不修改
OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/
......
......@@ -22,7 +22,7 @@
#include "include/paddlex/paddlex.h"
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(cfg_file, "", "Path of PaddelX model yml file");
DEFINE_string(cfg_dir, "", "Path of inference model");
DEFINE_string(device, "CPU", "Device name");
DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
......@@ -35,8 +35,8 @@ int main(int argc, char** argv) {
std::cerr << "--model_dir need to be defined" << std::endl;
return -1;
}
if (FLAGS_cfg_file == "") {
std::cerr << "--cfg_file need to be defined" << std::endl;
if (FLAGS_cfg_dir == "") {
std::cerr << "--cfg_dir need to be defined" << std::endl;
return -1;
}
if (FLAGS_image == "" & FLAGS_image_list == "") {
......@@ -46,7 +46,7 @@ int main(int argc, char** argv) {
// 加载模型
PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_cfg_file, FLAGS_device);
model.Init(FLAGS_model_dir, FLAGS_cfg_dir, FLAGS_device);
// 进行预测
if (FLAGS_image_list != "") {
......@@ -62,7 +62,7 @@ int main(int argc, char** argv) {
model.predict(im, &result);
std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl;
<< ", score: " << result.score << std::endl;
}
} else {
PaddleX::ClsResult result;
......
......@@ -13,8 +13,6 @@
// limitations under the License.
#include "include/paddlex/paddlex.h"
#include <iostream>
#include <fstream>
using namespace InferenceEngine;
......@@ -52,24 +50,20 @@ bool Model::load_config(const std::string& cfg_dir) {
}
// 构建数据处理流
transforms_.Init(config["Transforms"], to_rgb);
// 读入label lis
for (const auto& item : config["_Attributes"]["labels"]) {
int index = labels.size();
labels[index] = item.as<std::string>();
}
// 读入label list
labels.clear();
labels = config["_Attributes"]["labels"].as<std::vector<std::string>>();
return true;
}
bool Model::preprocess(cv::Mat* input_im, ImageBlob* inputs) {
if (!transforms_.Run(input_im, inputs)) {
bool Model::preprocess(cv::Mat* input_im) {
if (!transforms_.Run(input_im, inputs_)) {
return false;
}
return true;
}
bool Model::predict(const cv::Mat& im, ClsResult* result) {
inputs_.clear();
if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!"
......@@ -84,17 +78,17 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
// 处理输入图像
InferRequest infer_request = executable_network_.CreateInferRequest();
std::string input_name = network_.getInputsInfo().begin()->first;
inputs_.blob = infer_request.GetBlob(input_name);
cv::Mat im_clone = im.clone();
if (!preprocess(&im_clone, &inputs_)) {
inputs_ = infer_request.GetBlob(input_name);
auto im_clone = im.clone();
if (!preprocess(&im_clone)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
infer_request.Infer();
std::string output_name = network_.getOutputsInfo().begin()->first;
std::cout << "ouput node name" << output_name << std::endl;
output_ = infer_request.GetBlob(output_name);
MemoryBlob::CPtr moutput = as<MemoryBlob>(output_);
auto moutputHolder = moutput->rmap();
......@@ -105,122 +99,10 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
result->category_id = std::distance(outputs_data, ptr);
result->score = *ptr;
result->category = labels[result->category_id];
//for (int i=0;i<sizeof(outputs_data);i++){
// std::cout << labels[i] << std::endl;
// std::cout << outputs_[i] << std::endl;
// }
}
// Runs segmentation inference on `im` and stores the label and score maps in
// `result`, mapped back through the recorded padding/resize steps.
// Returns false when the loaded model is not a segmenter or when
// preprocessing fails; returns true otherwise.
bool Model::predict(const cv::Mat& im, SegResult* result) {
// Start from a clean result and a clean reshape history.
result->clear();
inputs_.clear();
if (type == "classifier") {
std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
"to function predict()!" << std::endl;
return false;
} else if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!" << std::endl;
return false;
}
// Create the request and fetch the blob of the network's first input; the
// preprocessing step writes the image into it.
InferRequest infer_request = executable_network_.CreateInferRequest();
std::string input_name = network_.getInputsInfo().begin()->first;
inputs_.blob = infer_request.GetBlob(input_name);
// Preprocess a copy so the caller's image is left unmodified.
cv::Mat im_clone = im.clone();
if (!preprocess(&im_clone, &inputs_)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
// Run inference synchronously.
infer_request.Infer();
OutputsDataMap out_map = network_.getOutputsInfo();
auto iter = out_map.begin();
// The first output entry is skipped; the second is read as the score map
// and the third as the label map.
// NOTE(review): assumes the network exposes at least three outputs in this
// order - confirm against the exported model.
iter++;
std::string output_name_score = iter->first;
Blob::Ptr output_score = infer_request.GetBlob(output_name_score);
MemoryBlob::CPtr moutput_score = as<MemoryBlob>(output_score);
TensorDesc blob_score = moutput_score->getTensorDesc();
std::vector<size_t> output_score_shape = blob_score.getDims();
// Record the score blob's dims and compute its element count.
int size = 1;
for (auto& i : output_score_shape) {
size *= static_cast<int>(i);
result->score_map.shape.push_back(static_cast<int>(i));
}
result->score_map.data.resize(size);
auto moutputHolder_score = moutput_score->rmap();
float* score_data = moutputHolder_score.as<float *>();
// NOTE(review): copies byteSize() bytes into `size` floats - presumably the
// blob holds 32-bit floats; confirm.
memcpy(result->score_map.data.data(),score_data,moutput_score->byteSize());
iter++;
std::string output_name_label = iter->first;
Blob::Ptr output_label = infer_request.GetBlob(output_name_label);
MemoryBlob::CPtr moutput_label = as<MemoryBlob>(output_label);
TensorDesc blob_label = moutput_label->getTensorDesc();
std::vector<size_t> output_label_shape = blob_label.getDims();
// Same bookkeeping for the label blob.
size = 1;
for (auto& i : output_label_shape) {
size *= static_cast<int>(i);
result->label_map.shape.push_back(static_cast<int>(i));
}
result->label_map.data.resize(size);
auto moutputHolder_label = moutput_label->rmap();
int* label_data = moutputHolder_label.as<int *>();
// NOTE(review): the blob is read as `int` and copied byte-wise - this assumes
// the blob's element size matches the element type of label_map.data; confirm.
memcpy(result->label_map.data.data(),label_data,moutput_label->byteSize());
// Narrow every label to uint8_t so it can back an 8-bit single-channel image.
std::vector<uint8_t> label_map(result->label_map.data.begin(),
result->label_map.data.end());
// Both cv::Mat objects are built on top of existing buffers (label_map and
// score_map.data). The label map takes rows/cols from shape[1]/shape[2], the
// score map from shape[2]/shape[3].
// NOTE(review): presumably labels are N x H x W and scores N x C x H x W; only
// the first rows*cols floats of the score buffer are covered - confirm.
cv::Mat mask_label(result->label_map.shape[1],
result->label_map.shape[2],
CV_8UC1,
label_map.data());
cv::Mat mask_score(result->score_map.shape[2],
result->score_map.shape[3],
CV_32FC1,
result->score_map.data.data());
// Undo the recorded reshaping steps in reverse order. `len_postprocess` is
// fixed while `idx` grows and entries are popped, so
// `len_postprocess - idx` always addresses the current last entry.
int idx = 1;
int len_postprocess = inputs_.im_size_before_resize_.size();
for (std::vector<std::string>::reverse_iterator iter =
inputs_.reshape_order_.rbegin();
iter != inputs_.reshape_order_.rend();
++iter) {
if (*iter == "padding") {
// Entries are stored as {rows, cols}, so `padding_w` actually holds the row
// count and `padding_h` the column count; cv::Rect takes (x, y, width,
// height), so the crop below still ends up as cols x rows.
auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
inputs_.im_size_before_resize_.pop_back();
auto padding_w = before_shape[0];
auto padding_h = before_shape[1];
mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
} else if (*iter == "resize") {
// Same naming swap as above: cv::Size takes (width, height) and receives
// (cols, rows). Labels use nearest-neighbour, scores linear interpolation.
auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
inputs_.im_size_before_resize_.pop_back();
auto resize_w = before_shape[0];
auto resize_h = before_shape[1];
cv::resize(mask_label,
mask_label,
cv::Size(resize_h, resize_w),
0,
0,
cv::INTER_NEAREST);
cv::resize(mask_score,
mask_score,
cv::Size(resize_h, resize_w),
0,
0,
cv::INTER_LINEAR);
}
++idx;
}
// Copy the restored maps back into the result; shapes become {rows, cols}.
result->label_map.data.assign(mask_label.begin<uint8_t>(),
mask_label.end<uint8_t>());
result->label_map.shape = {mask_label.rows, mask_label.cols};
result->score_map.data.assign(mask_score.begin<float>(),
mask_score.end<float>());
result->score_map.shape = {mask_score.rows, mask_score.cols};
return true;
}
} // namespce of PaddleX
......@@ -13,7 +13,6 @@
// limitations under the License.
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
......@@ -27,7 +26,7 @@ std::map<std::string, int> interpolations = {{"LINEAR", cv::INTER_LINEAR},
{"CUBIC", cv::INTER_CUBIC},
{"LANCZOS4", cv::INTER_LANCZOS4}};
bool Normalize::Run(cv::Mat* im, ImageBlob* data){
bool Normalize::Run(cv::Mat* im){
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
......@@ -41,6 +40,19 @@ bool Normalize::Run(cv::Mat* im, ImageBlob* data){
return true;
}
// Replaces *im with its centered width_ x height_ window.
// Returns false (after reporting on stderr) when the image is smaller than
// the crop window in either dimension.
bool CenterCrop::Run(cv::Mat* im) {
  const int rows = im->rows;
  const int cols = im->cols;
  const bool fits = rows >= height_ && cols >= width_;
  if (!fits) {
    std::cerr << "[CenterCrop] Image size less than crop size" << std::endl;
    return false;
  }
  // Integer division: an odd leftover pixel ends up on the right/bottom side.
  const int left = (cols - width_) / 2;
  const int top = (rows - height_) / 2;
  *im = (*im)(cv::Rect(left, top, width_, height_));
  return true;
}
float ResizeByShort::GenerateScale(const cv::Mat& im) {
......@@ -58,109 +70,11 @@ float ResizeByShort::GenerateScale(const cv::Mat& im) {
return scale;
}
bool ResizeByShort::Run(cv::Mat* im, ImageBlob* data) {
data->im_size_before_resize_.push_back({im->rows, im->cols});
data->reshape_order_.push_back("resize");
bool ResizeByShort::Run(cv::Mat* im) {
float scale = GenerateScale(*im);
int width = static_cast<int>(scale * im->cols);
int height = static_cast<int>(scale * im->rows);
cv::resize(*im, *im, cv::Size(width, height), 0, 0, cv::INTER_LINEAR);
data->new_im_size_[0] = im->rows;
data->new_im_size_[1] = im->cols;
data->scale = scale;
return true;
}
// Replaces *im with its centered width_ x height_ window and records the new
// size in `data`. Returns false (after reporting on stderr) when the image
// is smaller than the crop window in either dimension.
bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
  const int rows = im->rows;
  const int cols = im->cols;
  const bool fits = rows >= height_ && cols >= width_;
  if (!fits) {
    std::cerr << "[CenterCrop] Image size less than crop size" << std::endl;
    return false;
  }
  // Integer division: an odd leftover pixel ends up on the right/bottom side.
  const int left = (cols - width_) / 2;
  const int top = (rows - height_) / 2;
  *im = (*im)(cv::Rect(left, top, width_, height_));
  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}
// Pads *im on its bottom and right edges with zeros.
// When both target sides are configured (> 1) the image is padded up to
// width_ x height_; otherwise, when a stride is configured, each side is
// rounded up to the next multiple of coarsest_stride_. Returns false when the
// computed padding is negative, i.e. the image is larger than the target.
// The size before padding is recorded in `data` only once the padding is
// known to be valid, so a failed call leaves no stale history entry.
bool Padding::Run(cv::Mat* im, ImageBlob* data) {
  const int rows = im->rows;
  const int cols = im->cols;
  int padding_w = 0;
  int padding_h = 0;
  if (width_ > 1 && height_ > 1) {
    padding_w = width_ - cols;
    padding_h = height_ - rows;
  } else if (coarsest_stride_ >= 1) {
    // Integer round-up to the next multiple of the stride.
    padding_h =
        (rows + coarsest_stride_ - 1) / coarsest_stride_ * coarsest_stride_ -
        rows;
    padding_w =
        (cols + coarsest_stride_ - 1) / coarsest_stride_ * coarsest_stride_ -
        cols;
  }
  if (padding_h < 0 || padding_w < 0) {
    std::cerr << "[Padding] Computed padding_h=" << padding_h
              << ", padding_w=" << padding_w
              << ", but they should be greater than 0." << std::endl;
    return false;
  }
  data->im_size_before_resize_.push_back({rows, cols});
  data->reshape_order_.push_back("padding");
  cv::copyMakeBorder(
      *im, *im, 0, padding_h, 0, padding_w, cv::BORDER_CONSTANT, cv::Scalar(0));
  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}
// Rescales *im so that its longer side becomes long_size_, keeping the aspect
// ratio, and records the step in `data`. Returns false when long_size_ is
// not positive.
bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) {
  if (long_size_ <= 0) {
    std::cerr << "[ResizeByLong] long_size should be greater than 0"
              << std::endl;
    return false;
  }

  // Remember the size before resizing so the step can be undone later.
  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  const int longer_side = std::max(im->cols, im->rows);
  const float ratio =
      static_cast<float>(long_size_) / static_cast<float>(longer_side);
  // An empty cv::Size lets OpenCV derive the output size from the factors.
  cv::resize(*im, *im, cv::Size(), ratio, ratio, cv::INTER_NEAREST);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  data->scale = ratio;
  return true;
}
// Resizes *im to width_ x height_ with the configured interpolation and
// records the step in `data`. Returns false when the target size is not
// positive or the interpolation name is not in the `interpolations` table.
bool Resize::Run(cv::Mat* im, ImageBlob* data) {
  const bool size_ok = width_ > 0 && height_ > 0;
  if (!size_ok) {
    std::cerr << "[Resize] width and height should be greater than 0"
              << std::endl;
    return false;
  }
  const auto interp_it = interpolations.find(interp_);
  if (interp_it == interpolations.end()) {
    std::cerr << "[Resize] Invalid interpolation method: '" << interp_ << "'"
              << std::endl;
    return false;
  }

  // Remember the size before resizing so the step can be undone later.
  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  const cv::Size target(width_, height_);
  cv::resize(*im, *im, target, 0, 0, interp_it->second);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}
......@@ -180,16 +94,10 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
const std::string& transform_name) {
if (transform_name == "Normalize") {
return std::make_shared<Normalize>();
} else if (transform_name == "ResizeByShort") {
return std::make_shared<ResizeByShort>();
} else if (transform_name == "CenterCrop") {
return std::make_shared<CenterCrop>();
} else if (transform_name == "Resize") {
return std::make_shared<Resize>();
} else if (transform_name == "Padding") {
return std::make_shared<Padding>();
} else if (transform_name == "ResizeByLong") {
return std::make_shared<ResizeByLong>();
} else if (transform_name == "ResizeByShort") {
return std::make_shared<ResizeByShort>();
} else {
std::cerr << "There's unexpected transform(name='" << transform_name
<< "')." << std::endl;
......@@ -197,20 +105,15 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
}
}
bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
bool Transforms::Run(cv::Mat* im, Blob::Ptr blob) {
// 按照transforms中预处理算子顺序处理图像
if (to_rgb_) {
cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
}
(*im).convertTo(*im, CV_32FC3);
data->ori_im_size_[0] = im->rows;
data->ori_im_size_[1] = im->cols;
data->new_im_size_[0] = im->rows;
data->new_im_size_[1] = im->cols;
for (int i = 0; i < transforms_.size(); ++i) {
if (!transforms_[i]->Run(im,data)) {
if (!transforms_[i]->Run(im)) {
std::cerr << "Apply transforms to image failed!" << std::endl;
return false;
}
......@@ -218,15 +121,13 @@ bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
// 将图像由NHWC转为NCHW格式
// 同时转为连续的内存块存储到Blob
SizeVector blobSize = data->blob->getTensorDesc().getDims();
SizeVector blobSize = blob->getTensorDesc().getDims();
const size_t width = blobSize[3];
const size_t height = blobSize[2];
const size_t channels = blobSize[1];
MemoryBlob::Ptr mblob = InferenceEngine::as<MemoryBlob>(data->blob);
MemoryBlob::Ptr mblob = InferenceEngine::as<MemoryBlob>(blob);
auto mblobHolder = mblob->wmap();
float *blob_data = mblobHolder.as<float *>();
for (size_t c = 0; c < channels; c++) {
for (size_t h = 0; h < height; h++) {
for (size_t w = 0; w < width; w++) {
......
# Build script for the PaddleX deployment demos (classifier / segmenter /
# detector) linked against the prebuilt library found under LITE_DIR.
cmake_minimum_required(VERSION 3.0)
project(PaddleX CXX C)

option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." OFF)

SET(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH})

# Locations of the prebuilt dependencies, supplied with -D<NAME>=<path>.
SET(LITE_DIR "" CACHE PATH "Location of libraries")
SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
SET(NGRAPH_LIB "" CACHE PATH "Location of libraries")
SET(GFLAGS_DIR "" CACHE PATH "Location of libraries")

# yaml-cpp is built as the external project `ext-yaml-cpp`.
include(cmake/yaml-cpp.cmake)

include_directories("${CMAKE_SOURCE_DIR}/")
link_directories("${CMAKE_CURRENT_BINARY_DIR}")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/src/ext-yaml-cpp/include")
link_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/lib")

# Switches every C++ flag set from the dynamic (/MD) to the static (/MT)
# MSVC runtime.
macro(safe_set_static_flag)
    foreach(flag_var
        CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
        CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
      if(${flag_var} MATCHES "/MD")
        string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
      endif(${flag_var} MATCHES "/MD")
    endforeach(flag_var)
endmacro()

# The variables are quoted so that an empty value still yields a well-formed
# condition and the intended message is shown.
if (NOT DEFINED LITE_DIR OR "${LITE_DIR}" STREQUAL "")
    message(FATAL_ERROR "please set LITE_DIR with -DLITE_DIR=/path/to/paddle_lite")
endif()

if (NOT DEFINED OPENCV_DIR OR "${OPENCV_DIR}" STREQUAL "")
    message(FATAL_ERROR "please set OPENCV_DIR with -DOPENCV_DIR=/path/opencv")
endif()

if (NOT DEFINED GFLAGS_DIR OR "${GFLAGS_DIR}" STREQUAL "")
    message(FATAL_ERROR "please set GFLAGS_DIR with -DGFLAGS_DIR=/path/gflags")
endif()

link_directories("${LITE_DIR}/lib")
include_directories("${LITE_DIR}/include")

link_directories("${GFLAGS_DIR}/lib")
include_directories("${GFLAGS_DIR}/include")

if (WIN32)
  find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH)
  unset(OpenCV_DIR CACHE)
else ()
  find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/cmake NO_DEFAULT_PATH)
endif ()

include_directories(${OpenCV_INCLUDE_DIRS})

if (WIN32)
    add_definitions("/DGOOGLE_GLOG_DLL_DECL=")
    set(CMAKE_C_FLAGS_DEBUG   "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd")
    set(CMAKE_C_FLAGS_RELEASE  "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT")
    set(CMAKE_CXX_FLAGS_DEBUG  "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd")
    set(CMAKE_CXX_FLAGS_RELEASE   "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT")
    if (WITH_STATIC_LIB)
        safe_set_static_flag()
        add_definitions(-DSTATIC_LIB)
    endif()
else()
    # ARM hard-float / NEON flags for the target board. The optimisation level
    # is spelled -O2; lower-case -o2 would be parsed as an output-file option.
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=hard -mfpu=neon-vfpv4 -g -O2 -fopenmp -std=c++11")
    set(CMAKE_STATIC_LIBRARY_PREFIX "")
endif()

# NOTE(review): the static branch keeps the "_shared" library stem and only
# swaps the suffix - confirm that such a file exists in the LITE_DIR package.
if(WITH_STATIC_LIB)
    set(DEPS ${LITE_DIR}/lib/libpaddle_full_api_shared${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
    set(DEPS ${LITE_DIR}/lib/libpaddle_full_api_shared${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()

if (NOT WIN32)
    set(DEPS ${DEPS}
        glog gflags  z  yaml-cpp
        )
else()
    set(DEPS ${DEPS}
        glog gflags_static libprotobuf zlibstatic xxhash libyaml-cppmt)
    set(DEPS ${DEPS} libcmt shlwapi)
endif(NOT WIN32)

if (NOT WIN32)
    set(EXTERNAL_LIB "-ldl -lrt -lgomp -lz -lm -lpthread")
    set(DEPS ${DEPS} ${EXTERNAL_LIB})
endif()

set(DEPS ${DEPS} ${OpenCV_LIBS})

# Demo executables; each one waits for the external yaml-cpp build.
add_executable(classifier demo/classifier.cpp src/transforms.cpp src/paddlex.cpp)
ADD_DEPENDENCIES(classifier ext-yaml-cpp)
target_link_libraries(classifier ${DEPS})

add_executable(segmenter demo/segmenter.cpp src/transforms.cpp src/paddlex.cpp src/visualize.cpp)
ADD_DEPENDENCIES(segmenter ext-yaml-cpp)
target_link_libraries(segmenter ${DEPS})

add_executable(detector demo/detector.cpp src/transforms.cpp src/paddlex.cpp src/visualize.cpp)
ADD_DEPENDENCIES(detector ext-yaml-cpp)
target_link_libraries(detector ${DEPS})
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Downloads and builds yaml-cpp as the external project target `ext-yaml-cpp`.
# The archive is fetched from a fixed URL and verified against URL_MD5; tests,
# tools, contrib and install are switched off and a static library is built.
# NOTE(review): the download uses URL rather than a git checkout -- confirm
# whether `find_package(Git REQUIRED)` is still needed.
find_package(Git REQUIRED)
include(ExternalProject)
# Prints the active build type while configuring.
message("${CMAKE_BUILD_TYPE}")
ExternalProject_Add(
ext-yaml-cpp
URL https://bj.bcebos.com/paddlex/deploy/deps/yaml-cpp.zip
URL_MD5 9542d6de397d1fbd649ed468cb5850e6
CMAKE_ARGS
-DYAML_CPP_BUILD_TESTS=OFF
-DYAML_CPP_BUILD_TOOLS=OFF
-DYAML_CPP_INSTALL=OFF
-DYAML_CPP_BUILD_CONTRIB=OFF
-DMSVC_SHARED_RT=OFF
-DBUILD_SHARED_LIBS=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=${CMAKE_BINARY_DIR}/ext/yaml-cpp/lib
-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=${CMAKE_BINARY_DIR}/ext/yaml-cpp/lib
PREFIX "${CMAKE_BINARY_DIR}/ext/yaml-cpp"
# Disable install step
INSTALL_COMMAND ""
LOG_DOWNLOAD ON
LOG_BUILD 1
)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "include/paddlex/paddlex.h"
// Command-line flags (gflags) of the classifier demo.
// --model_dir and --cfg_dir are mandatory; at least one of --image and
// --image_list must be given (all three rules are enforced in main()).
DEFINE_string(model_dir, "", "Path of inference model");
// NOTE(review): despite the "_dir" suffix the help text describes a yml file.
DEFINE_string(cfg_dir, "", "Path of PaddelX model yml file");
DEFINE_string(image, "", "Path of test image file");
// Text file that main() reads line by line, one image path per line.
DEFINE_string(image_list, "", "Path of test image list file");
// Forwarded to Model::Init() as the thread count.
DEFINE_int32(thread_num, 1, "num of thread to infer");
// Entry point of the classification demo.
//
// Parses the command-line flags, loads the model and classifies either every
// image listed in --image_list (one path per line) or the single image given
// by --image.
//
// Returns 0 on success and -1 when a mandatory flag is missing, the image
// list cannot be opened, or the single image cannot be read / predicted.
// In list mode an unreadable or failing image is reported and skipped.
int main(int argc, char** argv) {
  // Parsing command-line
  google::ParseCommandLineFlags(&argc, &argv, true);

  if (FLAGS_model_dir.empty()) {
    std::cerr << "--model_dir need to be defined" << std::endl;
    return -1;
  }
  if (FLAGS_cfg_dir.empty()) {
    std::cerr << "--cfg_dir need to be defined" << std::endl;
    return -1;
  }
  // Logical AND: the original used the bitwise `&` operator here.
  if (FLAGS_image.empty() && FLAGS_image_list.empty()) {
    std::cerr << "--image or --image_list need to be defined" << std::endl;
    return -1;
  }

  // Load the model
  PaddleX::Model model;
  model.Init(FLAGS_model_dir, FLAGS_cfg_dir, FLAGS_thread_num);
  std::cout << "init is done" << std::endl;

  // Run prediction
  if (!FLAGS_image_list.empty()) {
    std::ifstream inf(FLAGS_image_list);
    if (!inf) {
      std::cerr << "Fail to open file " << FLAGS_image_list << std::endl;
      return -1;
    }
    // Number of images classified so far. Kept locally because the Model
    // class does not expose an image counter.
    int num_img = 0;
    std::string image_path;
    while (std::getline(inf, image_path)) {
      cv::Mat im = cv::imread(image_path, 1);
      if (im.empty()) {
        std::cerr << "Fail to read image " << image_path << std::endl;
        continue;
      }
      PaddleX::ClsResult result;
      if (!model.predict(im, &result)) {
        std::cerr << "Fail to predict image " << image_path << std::endl;
        continue;
      }
      ++num_img;
      std::cout << "Predict label: " << result.category
                << ", label_id:" << result.category_id
                << ", score: " << result.score
                << ", num_img: " << num_img << std::endl;
    }
  } else {
    cv::Mat im = cv::imread(FLAGS_image, 1);
    if (im.empty()) {
      std::cerr << "Fail to read image " << FLAGS_image << std::endl;
      return -1;
    }
    PaddleX::ClsResult result;
    if (!model.predict(im, &result)) {
      std::cerr << "Fail to predict image " << FLAGS_image << std::endl;
      return -1;
    }
    std::cout << "Predict label: " << result.category
              << ", label_id:" << result.category_id
              << ", score: " << result.score << std::endl;
  }
  return 0;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <omp.h>
#include <algorithm>
#include <chrono> // NOLINT
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <utility>
#include "include/paddlex/paddlex.h"
#include "include/paddlex/visualize.h"
using namespace std::chrono; // NOLINT
// Command-line flags (gflags) of the detector demo.
// --model_dir and --cfg_dir are mandatory; at least one of --image and
// --image_list must be given (all three rules are enforced in main()).
DEFINE_string(model_dir, "", "Path of openvino model xml file");
DEFINE_string(cfg_dir, "", "Path of PaddleX model yaml file");
DEFINE_string(image, "", "Path of test image file");
// Text file that main() reads line by line, one image path per line.
DEFINE_string(image_list, "", "Path of test image list file");
// Forwarded to Model::Init() as the thread count.
DEFINE_int32(thread_num, 1, "num of thread to infer");
// When non-empty, main() writes the visualized detection result there.
DEFINE_string(save_dir, "", "Path to save visualized image");
// NOTE(review): batch_size is not referenced by main() of this demo.
DEFINE_int32(batch_size, 1, "Batch size of infering");
// Passed to PaddleX::Visualize() as the score threshold.
DEFINE_double(threshold,
0.5,
"The minimum scores of target boxes which are shown");
// Entry point of the detection demo.
//
// Parses the command-line flags, loads the model and runs detection either
// on every image listed in --image_list (one path per line) or on the single
// image given by --image. When --save_dir is set, the visualized result of
// each processed image is written below that directory.
//
// Returns 0 on success and -1 when a mandatory flag is missing, the image
// list cannot be opened, or the single image cannot be read / predicted.
// In list mode an unreadable or failing image is reported and skipped.
int main(int argc, char** argv) {
  google::ParseCommandLineFlags(&argc, &argv, true);

  if (FLAGS_model_dir.empty()) {
    std::cerr << "--model_dir need to be defined" << std::endl;
    return -1;
  }
  if (FLAGS_cfg_dir.empty()) {
    std::cerr << "--cfg_dir need to be defined" << std::endl;
    return -1;
  }
  // Logical AND: the original used the bitwise `&` operator here.
  if (FLAGS_image.empty() && FLAGS_image_list.empty()) {
    std::cerr << "--image or --image_list need to be defined" << std::endl;
    return -1;
  }

  // Load the model
  PaddleX::Model model;
  model.Init(FLAGS_model_dir, FLAGS_cfg_dir, FLAGS_thread_num);
  auto colormap = PaddleX::GenerateColorMap(model.labels.size());

  // Run prediction
  if (!FLAGS_image_list.empty()) {
    std::ifstream inf(FLAGS_image_list);
    if (!inf) {
      std::cerr << "Fail to open file " << FLAGS_image_list << std::endl;
      return -1;
    }
    std::string image_path;
    while (std::getline(inf, image_path)) {
      cv::Mat im = cv::imread(image_path, 1);
      if (im.empty()) {
        std::cerr << "Fail to read image " << image_path << std::endl;
        continue;
      }
      PaddleX::DetResult result;
      if (!model.predict(im, &result)) {
        std::cerr << "Fail to predict image " << image_path << std::endl;
        continue;
      }
      if (!FLAGS_save_dir.empty()) {
        cv::Mat vis_img = PaddleX::Visualize(
            im, result, model.labels, colormap, FLAGS_threshold);
        // Derive the output name from the image being processed; the
        // original used FLAGS_image, which is empty in list mode.
        std::string save_path =
            PaddleX::generate_save_path(FLAGS_save_dir, image_path);
        cv::imwrite(save_path, vis_img);
        std::cout << "Visualized output saved as " << save_path << std::endl;
      }
    }
  } else {
    cv::Mat im = cv::imread(FLAGS_image, 1);
    if (im.empty()) {
      std::cerr << "Fail to read image " << FLAGS_image << std::endl;
      return -1;
    }
    PaddleX::DetResult result;
    if (!model.predict(im, &result)) {
      std::cerr << "Fail to predict image " << FLAGS_image << std::endl;
      return -1;
    }
    for (const auto& box : result.boxes) {
      std::cout << "image file: " << FLAGS_image << std::endl;
      std::cout << ", predict label: " << box.category
                << ", label_id:" << box.category_id
                << ", score: " << box.score
                << ", box(xmin, ymin, w, h):(" << box.coordinate[0]
                << ", " << box.coordinate[1] << ", "
                << box.coordinate[2] << ", "
                << box.coordinate[3] << ")" << std::endl;
    }
    if (!FLAGS_save_dir.empty()) {
      // Visualize the detections and store the image
      cv::Mat vis_img = PaddleX::Visualize(
          im, result, model.labels, colormap, FLAGS_threshold);
      std::string save_path =
          PaddleX::generate_save_path(FLAGS_save_dir, FLAGS_image);
      cv::imwrite(save_path, vis_img);
      std::cout << "Visualized output saved as " << save_path << std::endl;
    }
  }
  return 0;
}
......@@ -25,12 +25,12 @@
DEFINE_string(model_dir, "", "Path of openvino model xml file");
DEFINE_string(cfg_file, "", "Path of PaddleX model yaml file");
DEFINE_string(cfg_dir, "", "Path of PaddleX model yaml file");
DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file");
DEFINE_string(device, "CPU", "Device name");
DEFINE_string(save_dir, "", "Path to save visualized image");
DEFINE_int32(batch_size, 1, "Batch size of infering");
DEFINE_int32(thread_num, 1, "num of thread to infer");
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
......@@ -38,8 +38,8 @@ int main(int argc, char** argv) {
std::cerr << "--model_dir need to be defined" << std::endl;
return -1;
}
if (FLAGS_cfg_file == "") {
std::cerr << "--cfg_file need to be defined" << std::endl;
if (FLAGS_cfg_dir == "") {
std::cerr << "--cfg_dir need to be defined" << std::endl;
return -1;
}
if (FLAGS_image == "" & FLAGS_image_list == "") {
......@@ -48,8 +48,10 @@ int main(int argc, char** argv) {
}
//
std::cout << "init start" << std::endl;
PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_cfg_file, FLAGS_device);
model.Init(FLAGS_model_dir, FLAGS_cfg_dir, FLAGS_thread_num);
std::cout << "init done" << std::endl;
int imgs = 1;
auto colormap = PaddleX::GenerateColorMap(model.labels.size());
......@@ -60,6 +62,7 @@ int main(int argc, char** argv) {
return -1;
}
std::string image_path;
while (getline(inf, image_path)) {
PaddleX::SegResult result;
cv::Mat im = cv::imread(image_path, 1);
......@@ -71,7 +74,9 @@ int main(int argc, char** argv) {
cv::imwrite(save_path, vis_img);
std::cout << "Visualized output saved as " << save_path << std::endl;
}
}
}else{
PaddleX::SegResult result;
cv::Mat im = cv::imread(FLAGS_image, 1);
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "yaml-cpp/yaml.h"
#ifdef _WIN32
#define OS_PATH_SEP "\\"
#else
#define OS_PATH_SEP "/"
#endif
namespace PaddleX {
// Inference model configuration parser
class ConfigPaser {
public:
ConfigPaser() {}
~ConfigPaser() {}
bool load_config(const std::string& model_dir,
const std::string& cfg = "model.yml") {
// Load as a YAML::Node
YAML::Node config;
config = YAML::LoadFile(model_dir + OS_PATH_SEP + cfg);
if (config["Transforms"].IsDefined()) {
YAML::Node transforms_ = config["Transforms"];
} else {
std::cerr << "There's no field 'Transforms' in model.yml" << std::endl;
return false;
}
return true;
}
YAML::Node Transforms_;
};
} // namespace PaddleDetection
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <iostream>
#include <numeric>
#include <chrono>
#include "yaml-cpp/yaml.h"
#ifdef _WIN32
#define OS_PATH_SEP "\\"
#else
#define OS_PATH_SEP "/"
#endif
#include "paddle_api.h"
#include <arm_neon.h>
#include "include/paddlex/config_parser.h"
#include "include/paddlex/results.h"
#include "include/paddlex/transforms.h"
using namespace paddle::lite_api;
namespace PaddleX {

// Bundles a Paddle-Lite predictor with the configuration read from the
// PaddleX yml file: model type / name, label table and preprocessing
// pipeline.
class Model {
 public:
  // Initializes the model; simply forwards to create_predictor().
  //
  // @param model_dir  model location handed to create_predictor()
  // @param cfg_dir    configuration location handed to create_predictor()
  // @param thread_num thread count handed to create_predictor()
  void Init(const std::string& model_dir,
            const std::string& cfg_dir,
            int thread_num) {
    create_predictor(model_dir, cfg_dir, thread_num);
  }

  // Builds `predictor_` and loads the configuration (defined out of line).
  void create_predictor(const std::string& model_dir,
                        const std::string& cfg_dir,
                        int thread_num);

  // Reads the yml configuration located at `cfg_dir`. The parameter was
  // previously declared as `model_dir`, which did not match its definition.
  bool load_config(const std::string& cfg_dir);

  // Runs the preprocessing pipeline on `input_im`, filling `inputs`.
  bool preprocess(cv::Mat* input_im, ImageBlob* inputs);

  // Prediction entry points, one overload per result kind.
  bool predict(const cv::Mat& im, ClsResult* result);

  bool predict(const cv::Mat& im, DetResult* result);

  bool predict(const cv::Mat& im, SegResult* result);

  // Model type as stored in the configuration.
  std::string type;
  // Model name as stored in the configuration.
  std::string name;
  // Label id -> label name.
  std::map<int, std::string> labels;
  // Preprocessing pipeline.
  Transforms transforms_;
  // Preprocessed input shared by the predict() overloads.
  ImageBlob inputs_;
  // Paddle-Lite predictor instance.
  std::shared_ptr<PaddlePredictor> predictor_;
};

}  // namespace PaddleX
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <string>
#include <vector>
namespace PaddleX {

// Flat data buffer together with the shape that describes it.
template <class T>
struct Mask {
  // Element values.
  std::vector<T> data;
  // Extent of every dimension of `data`.
  std::vector<int> shape;

  // Drops both the values and the shape.
  void clear() {
    data.clear();
    shape.clear();
  }
};

// One detected object.
struct Box {
  // Label id; 0 until a prediction fills it in.
  int category_id = 0;
  // Label name.
  std::string category;
  // Confidence score; 0 until a prediction fills it in.
  float score = 0.0f;
  // Box coordinates as written by the predictor.
  std::vector<float> coordinate;
  // Optional mask attached to the box.
  Mask<float> mask;
};

// Common base of all result types.
class BaseResult {
 public:
  std::string type = "base";
};

// Result of a classification model.
class ClsResult : public BaseResult {
 public:
  // Label id; 0 until a prediction fills it in.
  int category_id = 0;
  // Label name.
  std::string category;
  // Confidence score; 0 until a prediction fills it in.
  float score = 0.0f;
  // Note: this member hides BaseResult::type, so reading `type` through a
  // BaseResult reference still yields "base".
  std::string type = "cls";
};

// Result of a detection model.
class DetResult : public BaseResult {
 public:
  // All detected boxes.
  std::vector<Box> boxes;
  // Mask resolution; 0 until a prediction fills it in.
  int mask_resolution = 0;
  // Note: this member hides BaseResult::type (see ClsResult::type).
  std::string type = "det";

  // Removes every stored box.
  void clear() { boxes.clear(); }
};

// Result of a segmentation model.
class SegResult : public BaseResult {
 public:
  // Per-pixel label ids.
  Mask<int64_t> label_map;
  // Per-pixel scores.
  Mask<float> score_map;

  // Empties both maps.
  void clear() {
    label_map.clear();
    score_map.clear();
  }
};

}  // namespace PaddleX
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <yaml-cpp/yaml.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "paddle_api.h"
using namespace paddle::lite_api;
namespace PaddleX {
/*
* @brief
* This class represents object for storing all preprocessed data:
* the image sizes recorded while preprocessing, the resize scale and the
* tensor that receives the preprocessed image.
* */
class ImageBlob {
public:
// Original image height and width
std::vector<int> ori_im_size_ = std::vector<int>(2);
// Newest image height and width after process
std::vector<int> new_im_size_ = std::vector<int>(2);
// Image height and width before resize
std::vector<std::vector<int>> im_size_before_resize_;
// Reshape order
std::vector<std::string> reshape_order_;
// Resize scale
float scale = 1.0;
// Buffer for image data after preprocessing
std::unique_ptr<Tensor> input_tensor_;
// Clears only the per-image histories (im_size_before_resize_ and
// reshape_order_); the size vectors, scale and input_tensor_ are untouched.
void clear() {
im_size_before_resize_.clear();
reshape_order_.clear();
}
};
// Abstraction of a preprocessing operation.
//
// Every concrete operation is configured once from its YAML node through
// Init() and then applied to images through Run().
class Transform {
 public:
  // Virtual destructor: concrete operations are owned through pointers to
  // this base class, so destruction has to dispatch to the derived type.
  virtual ~Transform() = default;

  // Reads the parameters of the operation from its YAML node.
  virtual void Init(const YAML::Node& item) = 0;

  // Applies the operation to `im` in place and records bookkeeping data in
  // `data`; returns false on failure.
  virtual bool Run(cv::Mat* im, ImageBlob* data) = 0;
};
class Normalize : public Transform {
public:
virtual void Init(const YAML::Node& item) {
mean_ = item["mean"].as<std::vector<float>>();
std_ = item["std"].as<std::vector<float>>();
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
private:
std::vector<float> mean_;
std::vector<float> std_;
};
class ResizeByShort : public Transform {
public:
virtual void Init(const YAML::Node& item) {
short_size_ = item["short_size"].as<int>();
if (item["max_size"].IsDefined()) {
max_size_ = item["max_size"].as<int>();
} else {
max_size_ = -1;
}
};
virtual bool Run(cv::Mat* im, ImageBlob* data);
private:
float GenerateScale(const cv::Mat& im);
int short_size_;
int max_size_;
};
/*
* @brief
* This class execute resize by long operation on image matrix. At first, it resizes
* the long side of image matrix to specified length. Accordingly, the short side
* will be resized in the same proportion.
* */
class ResizeByLong : public Transform {
public:
virtual void Init(const YAML::Node& item) {
long_size_ = item["long_size"].as<int>();
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
private:
int long_size_;
};
/*
 * @brief
 * This class executes the resize operation on an image matrix. It resizes
 * width and height to the specified length.
 * */
class Resize : public Transform {
 public:
  // Reads `interp` (optional) and `target_size` from the YAML node.
  //
  // `target_size` is either a scalar (used for both sides) or a sequence
  // whose first two entries are width and height. A missing, malformed or
  // non-positive `target_size` prints an error and terminates the process.
  virtual void Init(const YAML::Node& item) {
    if (item["interp"].IsDefined()) {
      interp_ = item["interp"].as<std::string>();
    }

    // Start from an invalid size so that every path which fails to provide
    // both dimensions ends up in the error branch below instead of reading
    // uninitialized members.
    int height = 0;
    int width = 0;
    const YAML::Node target = item["target_size"];
    if (target.IsDefined()) {
      if (target.IsScalar()) {
        height = target.as<int>();
        width = height;
      } else if (target.IsSequence()) {
        const std::vector<int> target_size = target.as<std::vector<int>>();
        if (target_size.size() >= 2) {
          width = target_size[0];
          height = target_size[1];
        }
      }
    }
    height_ = height;
    width_ = width;

    if (height_ <= 0 || width_ <= 0) {
      std::cerr << "[Resize] target_size should greater than 0" << std::endl;
      exit(-1);
    }
  }

  virtual bool Run(cv::Mat* im, ImageBlob* data);

 private:
  int height_ = 0;
  int width_ = 0;
  // Stays empty when `interp` is not configured.
  std::string interp_;
};
// Crop operation configured by `crop_size`; Run() is defined out of line.
class CenterCrop : public Transform {
 public:
  // Reads `crop_size` from the YAML node.
  //
  // `crop_size` is either a scalar (used for both sides) or a sequence whose
  // first two entries are width and height. When it is missing or has fewer
  // than two entries, both dimensions are left at 0 instead of being
  // uninitialized.
  virtual void Init(const YAML::Node& item) {
    int height = 0;
    int width = 0;
    const YAML::Node crop = item["crop_size"];
    if (crop.IsDefined()) {
      if (crop.IsScalar()) {
        height = crop.as<int>();
        width = height;
      } else if (crop.IsSequence()) {
        const std::vector<int> crop_size = crop.as<std::vector<int>>();
        if (crop_size.size() >= 2) {
          width = crop_size[0];
          height = crop_size[1];
        }
      }
    }
    height_ = height;
    width_ = width;
  }

  virtual bool Run(cv::Mat* im, ImageBlob* data);

 private:
  int height_ = 0;
  int width_ = 0;
};
/*
* @brief
* This class execute padding operation on image matrix. It makes border on edge
* of image matrix.
* */
class Padding : public Transform {
public:
virtual void Init(const YAML::Node& item) {
if (item["coarsest_stride"].IsDefined()) {
coarsest_stride_ = item["coarsest_stride"].as<int>();
if (coarsest_stride_ < 1) {
std::cerr << "[Padding] coarest_stride should greater than 0"
<< std::endl;
exit(-1);
}
}
if (item["target_size"].IsDefined()) {
if (item["target_size"].IsScalar()) {
width_ = item["target_size"].as<int>();
height_ = item["target_size"].as<int>();
} else if (item["target_size"].IsSequence()) {
width_ = item["target_size"].as<std::vector<int>>()[0];
height_ = item["target_size"].as<std::vector<int>>()[1];
}
}
if (item["im_padding_value"].IsDefined()) {
im_value_ = item["im_padding_value"].as<std::vector<float>>();
}
else {
im_value_ = {0, 0, 0};
}
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
private:
int coarsest_stride_ = -1;
int width_ = 0;
int height_ = 0;
std::vector<float> im_value_;
};
// Ordered collection of preprocessing operations.
// All member functions are defined out of line.
class Transforms {
public:
// Initializes the pipeline from the YAML node `node`; `to_rgb` defaults to
// true.
void Init(const YAML::Node& node, bool to_rgb = true);
// Returns the operation object for the given operation name.
std::shared_ptr<Transform> CreateTransform(const std::string& name);
// Runs the pipeline on `im`, recording bookkeeping data in `data`.
bool Run(cv::Mat* im, ImageBlob* data);
private:
// Operations held by the pipeline.
std::vector<std::shared_ptr<Transform>> transforms_;
// NOTE(review): presumably selects BGR->RGB conversion before the
// operations run -- confirm against the implementation of Run().
bool to_rgb_ = true;
};
} // namespace PaddleX
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from six import text_type as _text_type
import argparse
......
......@@ -26,13 +26,6 @@ def arg_parser():
type=str,
default=None,
help="path to openvino model .xml file")
parser.add_argument(
"--device",
"-d",
type=str,
default='CPU',
help="Specify the target device to infer on:[CPU, GPU, FPGA, HDDL, MYRIAD,HETERO]"
"Default value is CPU")
parser.add_argument(
"--img",
"-i",
......@@ -49,27 +42,43 @@ def arg_parser():
parser.add_argument(
"--cfg_file",
"--cfg_dir",
"-c",
type=str,
default=None,
help="Path to PaddelX model yml file")
parser.add_argument(
"--thread_num",
"-t",
type=int,
default=1,
help="Path to PaddelX model yml file")
parser.add_argument(
"--input_shape",
"-ip",
type=str,
default=None,
help=" image input shape of model [NCHW] like [1,3,224,244] ")
return parser
def main():
parser = arg_parser()
args = parser.parse_args()
model_xml = args.model_dir
model_yaml = args.cfg_file
model_nb = args.model_dir
model_yaml = args.cfg_dir
thread_num = args.thread_num
input_shape = args.input_shape
input_shape = input_shape[1:-1].split(",",3)
shape = list(map(int,input_shape))
#model init
if("CPU" not in args.device):
predictor = deploy.Predictor(model_xml,model_yaml,args.device)
else:
predictor = deploy.Predictor(model_xml,model_yaml)
predictor = deploy.Predictor(model_nb,model_yaml,thread_num,shape)
#predict
if(args.img_list != None):
......
# Path to the precompiled Paddle-Lite library
LITE_DIR=/path/to/Paddle-Lite/inference/lib
# Path to the precompiled gflags library
GFLAGS_DIR=$(pwd)/deps/gflags
# Path to the precompiled glog library
GLOG_DIR=$(pwd)/deps/glog
# Path to the precompiled OpenCV library; no change is needed when the
# bundled precompiled version is used
OPENCV_DIR=$(pwd)/deps/opencv
# Download the bundled precompiled dependencies.
# The helper is run as a child process: `exec` would replace this shell with
# the helper, so none of the build steps below would ever be executed.
sh "$(pwd)/scripts/install_third-party.sh"
rm -rf build
mkdir -p build
# Stop instead of configuring in the wrong directory if `build` is missing.
cd build || exit 1
cmake .. \
    -DOPENCV_DIR=${OPENCV_DIR} \
    -DGFLAGS_DIR=${GFLAGS_DIR} \
    -DLITE_DIR=${LITE_DIR} \
    -DCMAKE_CXX_FLAGS="-march=armv7-a"
make
# Download and build the third-party libraries into ./deps.
# Every dependency is skipped when its directory already exists, so the
# script can be re-run safely.
if [ ! -d "./deps" ]; then
    mkdir deps
fi

# gflags: the guard has to test the directory that `git clone` creates
# ("gflags"). The original tested "gflag", which never exists, so a second
# run always tried to clone into the existing directory again.
if [ ! -d "./deps/gflags" ]; then
    cd deps
    git clone https://github.com/gflags/gflags
    cd gflags
    cmake .
    make -j 4
    cd ..
    cd ..
fi

# glog (built with autotools)
if [ ! -d "./deps/glog" ]; then
    cd deps
    git clone https://github.com/google/glog
    sudo apt-get install autoconf automake libtool
    cd glog
    ./autogen.sh
    ./configure
    make -j 4
    cd ..
    cd ..
fi

# Precompiled OpenCV for ARM
OPENCV_URL=https://bj.bcebos.com/paddlex/deploy/armopencv/opencv.tar.bz2
if [ ! -d "./deps/opencv" ]; then
    cd deps
    wget -c ${OPENCV_URL}
    tar xvfj opencv.tar.bz2
    rm -rf opencv.tar.bz2
    cd ..
fi
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/paddlex/paddlex.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>
using namespace paddle::lite_api;
namespace PaddleX {
void Model::create_predictor(const std::string& model_dir,
const std::string& cfg_dir,
int thread_num) {
MobileConfig config;
config.set_model_from_file(model_dir);
config.set_threads(thread_num);
load_config(cfg_dir);
predictor_ = CreatePaddlePredictor<MobileConfig>(config);
}
bool Model::load_config(const std::string& cfg_dir) {
YAML::Node config = YAML::LoadFile(cfg_dir);
type = config["_Attributes"]["model_type"].as<std::string>();
name = config["Model"].as<std::string>();
bool to_rgb = true;
if (config["TransformsMode"].IsDefined()) {
std::string mode = config["TransformsMode"].as<std::string>();
if (mode == "BGR") {
to_rgb = false;
} else if (mode != "RGB") {
std::cerr << "[Init] Only 'RGB' or 'BGR' is supported for TransformsMode"
<< std::endl;
return false;
}
}
// 构建数据处理流
transforms_.Init(config["Transforms"], to_rgb);
// 读入label lis
for (const auto& item : config["_Attributes"]["labels"]) {
int index = labels.size();
labels[index] = item.as<std::string>();
}
return true;
}
// Runs the preprocessing pipeline on `input_im`, filling `inputs`.
// Returns true when the pipeline succeeds and false otherwise.
bool Model::preprocess(cv::Mat* input_im, ImageBlob* inputs) {
  const bool succeeded = transforms_.Run(input_im, inputs);
  return succeeded ? true : false;
}
// Classifies `im` and stores the best-scoring class in `result`.
//
// @param im     input image; it is cloned before preprocessing
// @param result receives the id, name and score of the top class
// @return true on success; false when the loaded model is not a classifier,
//         preprocessing fails, or the model produced no output values.
bool Model::predict(const cv::Mat& im, ClsResult* result) {
  inputs_.clear();
  if (type == "detector") {
    std::cerr << "Loading model is a 'detector', DetResult should be passed to "
                 "function predict()!"
              << std::endl;
    return false;
  } else if (type == "segmenter") {
    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
                 "to function predict()!"
              << std::endl;
    return false;
  }

  // Preprocess the input image into the predictor's first input tensor
  inputs_.input_tensor_ = std::move(predictor_->GetInput(0));
  cv::Mat im_clone = im.clone();
  if (!preprocess(&im_clone, &inputs_)) {
    std::cerr << "Preprocess failed!" << std::endl;
    return false;
  }

  predictor_->Run();

  std::unique_ptr<const Tensor> output_tensor(
      std::move(predictor_->GetOutput(0)));
  const float* outputs_data = output_tensor->data<float>();

  // The number of scores is the product of the output dimensions. The
  // original used sizeof(outputs_data), which is the size of a pointer and
  // therefore scanned a fixed 4/8 values regardless of the class count.
  int64_t num_scores = 1;
  for (const auto dim : output_tensor->shape()) {
    num_scores *= dim;
  }
  if (outputs_data == nullptr || num_scores <= 0) {
    std::cerr << "Model output is empty!" << std::endl;
    return false;
  }

  // Post-process: pick the class with the highest score
  const float* best =
      std::max_element(outputs_data, outputs_data + num_scores);
  result->category_id = static_cast<int>(std::distance(outputs_data, best));
  result->score = *best;

  // Look the label up without inserting: operator[] would add an empty
  // entry for an unknown id and change the size of the label table.
  const auto label = labels.find(result->category_id);
  result->category = (label != labels.end()) ? label->second : std::string();

  // The original fell off the end of this non-void function.
  return true;
}
// Runs detection on `im` and stores the detected boxes in `result`.
//
// The first output tensor is read as consecutive records of six floats:
// label id, score, xmin, ymin, xmax, ymax. Every record becomes one Box
// whose coordinate is {xmin, ymin, w, h}.
//
// @param im     input image; it is cloned before preprocessing
// @param result cleared first, then filled with one Box per record
// @return true on success; false when the loaded model is not a detector
//         or preprocessing fails.
bool Model::predict(const cv::Mat& im, DetResult* result) {
  inputs_.clear();
  result->clear();
  if (type == "classifier") {
    std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
                 "to function predict()!" << std::endl;
    return false;
  } else if (type == "segmenter") {
    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
                 "to function predict()!" << std::endl;
    return false;
  }

  // Preprocess the input image into the predictor's first input tensor
  inputs_.input_tensor_ = std::move(predictor_->GetInput(0));
  cv::Mat im_clone = im.clone();
  if (!preprocess(&im_clone, &inputs_)) {
    std::cerr << "Preprocess failed!" << std::endl;
    return false;
  }

  if (name == "YOLOv3") {
    // YOLOv3 takes the original image size as a second input of shape {1, 2}
    std::unique_ptr<Tensor> im_size_tensor(std::move(predictor_->GetInput(1)));
    const std::vector<int64_t> IM_SIZE_SHAPE = {1, 2};
    im_size_tensor->Resize(IM_SIZE_SHAPE);
    auto* im_size_data = im_size_tensor->mutable_data<int>();
    memcpy(im_size_data, inputs_.ori_im_size_.data(), 1 * 2 * sizeof(int));
  }

  predictor_->Run();

  auto output_names = predictor_->GetOutputNames();
  auto output_box_tensor = predictor_->GetTensor(output_names[0]);
  const float* output_box = output_box_tensor->mutable_data<float>();

  // Total number of floats in the output
  const std::vector<int64_t> output_box_shape = output_box_tensor->shape();
  int size = 1;
  for (const auto& dim : output_box_shape) {
    size *= dim;
  }

  const int kBoxStride = 6;  // floats per detection record
  const int num_boxes = size / kBoxStride;
  result->boxes.reserve(num_boxes > 0 ? num_boxes : 0);
  for (int i = 0; i < num_boxes; ++i) {
    const float* record = output_box + i * kBoxStride;
    Box box;
    box.category_id = static_cast<int>(std::round(record[0]));
    // Look the label up without inserting: operator[] would add an empty
    // entry for an unknown id and change the size of the label table.
    const auto label = labels.find(box.category_id);
    box.category = (label != labels.end()) ? label->second : std::string();
    box.score = record[1];
    const float xmin = record[2];
    const float ymin = record[3];
    const float xmax = record[4];
    const float ymax = record[5];
    const float w = xmax - xmin + 1;
    const float h = ymax - ymin + 1;
    box.coordinate = {xmin, ymin, w, h};
    result->boxes.push_back(std::move(box));
  }
  return true;
}
// Runs segmentation on `im` and stores a label map and a score map in
// `result`, after undoing the "resize"/"padding" steps that preprocessing
// recorded in `inputs_`.
//
// Returns false when the loaded model type is "classifier" or "detector",
// or when preprocessing fails; returns true otherwise.
bool Model::predict(const cv::Mat& im, SegResult* result) {
result->clear();
inputs_.clear();
if (type == "classifier") {
std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
"to function predict()!" << std::endl;
return false;
} else if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!" << std::endl;
return false;
}
// Preprocess a copy of the image into the predictor's first input tensor.
inputs_.input_tensor_ = std::move(predictor_->GetInput(0));
cv::Mat im_clone = im.clone();
if (!preprocess(&im_clone, &inputs_)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
std::cout << "Preprocess is done" << std::endl;
predictor_->Run();
auto output_names = predictor_->GetOutputNames();
auto output_label_tensor = predictor_->GetTensor(output_names[0]);
// Diagnostic output: prints the names of the first two output tensors.
std::cout << "output0" << output_names[0] << std::endl;
std::cout << "output1" << output_names[1] << std::endl;
// Copy the first output (read as int64 values) into result->label_map and
// record its shape.
const int64_t *label_data = output_label_tensor->mutable_data<int64_t>();
std::vector<int64_t> output_label_shape = output_label_tensor->shape();
int size = 1;
for (const auto& i : output_label_shape) {
size *= i;
result->label_map.shape.push_back(i);
}
result->label_map.data.resize(size);
memcpy(result->label_map.data.data(), label_data, size*sizeof(int64_t));
// Copy the second output (read as floats) into result->score_map and record
// its shape.
auto output_score_tensor = predictor_->GetTensor(output_names[1]);
const float *score_data = output_score_tensor->mutable_data<float>();
std::vector<int64_t> output_score_shape = output_score_tensor->shape();
size = 1;
for (const auto& i : output_score_shape) {
size *= i;
result->score_map.shape.push_back(i);
}
result->score_map.data.resize(size);
memcpy(result->score_map.data.data(), score_data, size*sizeof(float));
// Narrow the labels to uint8_t so they can back a CV_8UC1 matrix; values that
// do not fit in 8 bits are truncated by this conversion.
std::vector<uint8_t> label_map(result->label_map.data.begin(),
result->label_map.data.end());
// Wrap the buffers as single-channel matrices without copying.
// NOTE(review): label dims 1 and 2, and score dims 2 and 3, are used as
// rows/cols — presumably the outputs are N x H x W and N x C x H x W; confirm.
cv::Mat mask_label(result->label_map.shape[1],
result->label_map.shape[2],
CV_8UC1,
label_map.data());
// mask_score covers only the first rows*cols floats of the score buffer.
cv::Mat mask_score(result->score_map.shape[2],
result->score_map.shape[3],
CV_32FC1,
result->score_map.data.data());
// Walk the recorded reshape operations from last to first and undo each one.
// `len_postprocess` is captured once; `idx` grows by one per iteration while
// the matching size entry is popped, so the index always addresses the
// current last element.
int idx = 1;
int len_postprocess = inputs_.im_size_before_resize_.size();
for (std::vector<std::string>::reverse_iterator iter =
inputs_.reshape_order_.rbegin();
iter != inputs_.reshape_order_.rend();
++iter) {
if (*iter == "padding") {
// The transforms record sizes as {rows, cols}, so `padding_w` holds the row
// count and `padding_h` the column count; cv::Rect therefore receives
// (x, y, cols, rows) and crops away the padded bottom/right area.
auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
inputs_.im_size_before_resize_.pop_back();
auto padding_w = before_shape[0];
auto padding_h = before_shape[1];
mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
} else if (*iter == "resize") {
// Same {rows, cols} naming swap as above: cv::Size receives (cols, rows).
// Labels use nearest-neighbour interpolation; scores use linear.
auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
inputs_.im_size_before_resize_.pop_back();
auto resize_w = before_shape[0];
auto resize_h = before_shape[1];
cv::resize(mask_label,
mask_label,
cv::Size(resize_h, resize_w),
0,
0,
cv::INTER_NEAREST);
cv::resize(mask_score,
mask_score,
cv::Size(resize_h, resize_w),
0,
0,
cv::INTER_LINEAR);
}
++idx;
}
// Write the restored maps back into the result as flat buffers with
// {rows, cols} shapes.
result->label_map.data.assign(mask_label.begin<uint8_t>(),
mask_label.end<uint8_t>());
result->label_map.shape = {mask_label.rows, mask_label.cols};
result->score_map.data.assign(mask_score.begin<float>(),
mask_score.end<float>());
result->score_map.shape = {mask_score.rows, mask_score.cols};
return true;
}
} // namespce of PaddleX
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include <vector>
#include <math.h>
#include "include/paddlex/transforms.h"
namespace PaddleX {
// Maps the interpolation names accepted in the transform configuration to the
// corresponding OpenCV interpolation flags. Resize::Run looks names up here.
std::map<std::string, int> interpolations = {{"LINEAR", cv::INTER_LINEAR},
{"NEAREST", cv::INTER_NEAREST},
{"AREA", cv::INTER_AREA},
{"CUBIC", cv::INTER_CUBIC},
{"LANCZOS4", cv::INTER_LANCZOS4}};
// Normalizes `im` in place: every channel value v of every pixel becomes
// (v / 255.0 - mean_[c]) / std_[c]. The image is accessed as 3-channel float
// pixels (cv::Vec3f). `data` is not used. Always returns true.
bool Normalize::Run(cv::Mat* im, ImageBlob* data) {
  for (int row = 0; row < im->rows; ++row) {
    for (int col = 0; col < im->cols; ++col) {
      cv::Vec3f& pixel = im->at<cv::Vec3f>(row, col);
      for (int ch = 0; ch < 3; ++ch) {
        pixel[ch] = (pixel[ch] / 255.0 - mean_[ch]) / std_[ch];
      }
    }
  }
  return true;
}
// Computes the scale factor that brings the shorter image side to
// `short_size_`. When `max_size_` is positive and the scaled longer side
// would round to more than `max_size_`, the factor is instead chosen so the
// longer side equals `max_size_`.
float ResizeByShort::GenerateScale(const cv::Mat& im) {
  const int longer_side = std::max(im.cols, im.rows);
  const int shorter_side = std::min(im.cols, im.rows);

  float scale =
      static_cast<float>(short_size_) / static_cast<float>(shorter_side);

  const bool capped = max_size_ > 0;
  if (capped && round(scale * longer_side) > max_size_) {
    scale = static_cast<float>(max_size_) / static_cast<float>(longer_side);
  }
  return scale;
}
// Resizes `im` in place by the factor from GenerateScale(), using linear
// interpolation. Before resizing, the current {rows, cols} and the tag
// "resize" are appended to `data` so the step can be undone later; afterwards
// the new size and the scale are stored in `data`. Always returns true.
bool ResizeByShort::Run(cv::Mat* im, ImageBlob* data) {
  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  const float factor = GenerateScale(*im);
  const int target_w = static_cast<int>(round(factor * im->cols));
  const int target_h = static_cast<int>(round(factor * im->rows));
  const cv::Size target_size(target_w, target_h);
  cv::resize(*im, *im, target_size, 0, 0, cv::INTER_LINEAR);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  data->scale = factor;
  return true;
}
// Crops a `width_` x `height_` window from the centre of `im` and stores the
// new size in `data`. Returns false, after logging, when the image is smaller
// than the crop window in either dimension.
bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
  const int src_h = static_cast<int>(im->rows);
  const int src_w = static_cast<int>(im->cols);

  const bool window_fits = src_h >= height_ && src_w >= width_;
  if (!window_fits) {
    std::cerr << "[CenterCrop] Image size less than crop size" << std::endl;
    return false;
  }

  const int left = static_cast<int>((src_w - width_) / 2);
  const int top = static_cast<int>((src_h - height_) / 2);
  *im = (*im)(cv::Rect(left, top, width_, height_));

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}
// Pads `im` on the bottom and right with the constant colour `im_value_`.
//
// The current {rows, cols} and the tag "padding" are appended to `data` first
// so the step can be undone later. The amount of padding is chosen as
// follows:
//   - if both `width_` and `height_` are greater than 1, pad up to exactly
//     that size;
//   - otherwise, if `coarsest_stride_` >= 1, pad each dimension up to the
//     next multiple of the stride;
//   - otherwise no padding is added.
//
// Returns false, after logging, when the computed padding is negative (the
// image is already larger than the target); returns true otherwise.
bool Padding::Run(cv::Mat* im, ImageBlob* data) {
  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("padding");

  int padding_w = 0;
  int padding_h = 0;
  // Logical AND: both target dimensions must be set (was bitwise '&').
  if (width_ > 1 && height_ > 1) {
    padding_w = width_ - im->cols;
    padding_h = height_ - im->rows;
  } else if (coarsest_stride_ >= 1) {
    const int h = im->rows;
    const int w = im->cols;
    // Round each dimension up to the next multiple of the stride.
    padding_h = static_cast<int>(
        ceil(h * 1.0 / coarsest_stride_) * coarsest_stride_ - h);
    padding_w = static_cast<int>(
        ceil(w * 1.0 / coarsest_stride_) * coarsest_stride_ - w);
  }

  if (padding_h < 0 || padding_w < 0) {
    std::cerr << "[Padding] Computed padding_h=" << padding_h
              << ", padding_w=" << padding_w
              << ", but they should be greater than 0." << std::endl;
    return false;
  }

  cv::Scalar value = cv::Scalar(im_value_[0], im_value_[1], im_value_[2]);
  cv::copyMakeBorder(
      *im, *im, 0, padding_h, 0, padding_w, cv::BORDER_CONSTANT, value);
  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}
// Resizes `im` in place so that its longer side is scaled to `long_size_`,
// using nearest-neighbour interpolation. The pre-resize {rows, cols} and the
// tag "resize" are appended to `data`, and the new size and scale are stored
// in it. Returns false, after logging, when `long_size_` is not positive.
bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) {
  const bool size_valid = long_size_ > 0;
  if (!size_valid) {
    std::cerr << "[ResizeByLong] long_size should be greater than 0"
              << std::endl;
    return false;
  }

  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  const int longer_side = std::max(im->cols, im->rows);
  const float factor =
      static_cast<float>(long_size_) / static_cast<float>(longer_side);
  cv::resize(*im, *im, cv::Size(), factor, factor, cv::INTER_NEAREST);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  data->scale = factor;
  return true;
}
// Resizes `im` in place to exactly `width_` x `height_` using the
// interpolation method named by `interp_`. The pre-resize {rows, cols} and
// the tag "resize" are appended to `data`, and the new size is stored in it.
// Returns false, after logging, when the target size is not positive or the
// interpolation name is not present in `interpolations`.
bool Resize::Run(cv::Mat* im, ImageBlob* data) {
  const bool size_valid = width_ > 0 && height_ > 0;
  if (!size_valid) {
    std::cerr << "[Resize] width and height should be greater than 0"
              << std::endl;
    return false;
  }

  const auto interp_entry = interpolations.find(interp_);
  if (interp_entry == interpolations.end()) {
    std::cerr << "[Resize] Invalid interpolation method: '" << interp_ << "'"
              << std::endl;
    return false;
  }

  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  const cv::Size target_size(width_, height_);
  cv::resize(*im, *im, target_size, 0, 0, interp_entry->second);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}
// Rebuilds the transform pipeline from a YAML sequence. For each entry, the
// first key is taken as the transform name and its value as that transform's
// configuration. Any previously configured transforms are discarded, and
// `to_rgb` is stored for use by Run().
void Transforms::Init(const YAML::Node& transforms_node, bool to_rgb) {
  transforms_.clear();
  to_rgb_ = to_rgb;
  for (const auto& entry : transforms_node) {
    const std::string op_name = entry.begin()->first.as<std::string>();
    std::cout << "trans name: " << op_name << std::endl;

    std::shared_ptr<Transform> op = CreateTransform(op_name);
    op->Init(entry.begin()->second);
    transforms_.push_back(op);
  }
}
// Creates the transform object matching `transform_name`.
// Recognised names: Normalize, ResizeByShort, CenterCrop, Resize, Padding,
// ResizeByLong. For any other name an error is logged and the process exits
// with status -1.
std::shared_ptr<Transform> Transforms::CreateTransform(
    const std::string& transform_name) {
  if (transform_name == "Normalize") {
    return std::make_shared<Normalize>();
  }
  if (transform_name == "ResizeByShort") {
    return std::make_shared<ResizeByShort>();
  }
  if (transform_name == "CenterCrop") {
    return std::make_shared<CenterCrop>();
  }
  if (transform_name == "Resize") {
    return std::make_shared<Resize>();
  }
  if (transform_name == "Padding") {
    return std::make_shared<Padding>();
  }
  if (transform_name == "ResizeByLong") {
    return std::make_shared<ResizeByLong>();
  }

  std::cerr << "There's unexpected transform(name='" << transform_name
            << "')." << std::endl;
  exit(-1);
}
// Runs the whole preprocessing pipeline on `im` and fills `data`.
//
// Steps: optional BGR->RGB conversion, conversion to float, recording of the
// starting size, each configured transform in order, and finally a copy of
// the pixels into `data->input_tensor_` in channel-major (NCHW) order.
// Returns false, after logging, as soon as any transform fails.
bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
  // Apply the preprocessing operators in the order given by the transforms.
  if (to_rgb_) {
    cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
  }
  im->convertTo(*im, CV_32FC3);

  data->ori_im_size_[0] = im->rows;
  data->ori_im_size_[1] = im->cols;
  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;

  for (const auto& op : transforms_) {
    if (!op->Run(im, data)) {
      std::cerr << "Apply transforms to image failed!" << std::endl;
      return false;
    }
  }

  // Convert the image from NHWC to NCHW layout while copying it into the
  // contiguous memory of the input tensor.
  const int rows = im->rows;
  const int cols = im->cols;
  const int channels = im->channels();
  const std::vector<int64_t> INPUT_SHAPE = {1, channels, rows, cols};
  data->input_tensor_->Resize(INPUT_SHAPE);
  auto* dst = data->input_tensor_->mutable_data<float>();

  for (size_t ch = 0; ch < channels; ++ch) {
    const size_t plane_offset = ch * cols * rows;
    for (size_t r = 0; r < rows; ++r) {
      const size_t row_offset = plane_offset + r * cols;
      for (size_t c = 0; c < cols; ++c) {
        dst[row_offset + c] = im->at<cv::Vec3f>(r, c)[ch];
      }
    }
  }
  return true;
}
} // namespace PaddleX
文件已添加
# 树莓派
PaddleX支持通过Paddle-Lite和基于OpenVINO的神经计算棒(NCS2)这两种方式在树莓派上完成预测部署。
## 硬件环境配置
对于尚未安装系统的树莓派,首先需要进行系统安装、环境配置等步骤来初始化硬件环境,过程中需要的软硬件如下:
- 硬件:micro SD,显示器,键盘,鼠标
- 软件:Raspbian OS
### Step1:系统安装
- 格式化micro SD卡为FAT格式,Windows和Mac下建议使用[SD Memory Card Formatter](https://www.sdcard.org/downloads/formatter/)工具,Linux下请参考[NOOBS For Raspberry Pi](http://qdosmsq.dunbar-it.co.uk/blog/2013/06/noobs-for-raspberry-pi/)
- 下载NOOBS版本的Raspbian OS [下载地址](https://www.raspberrypi.org/downloads/)并将解压后的文件复制到SD中,插入SD到树莓派上通电,然后将自动安装系统
### Step2:环境配置
- 启用VNC和SSH服务:打开LX终端,输入如下命令,选择Interfacing Option,然后分别选择P2 SSH和P3 VNC以开启SSH与VNC服务。开启后就可以通过SSH或者VNC的方式连接树莓派
```
sudo raspi-config
```
- 更换源:由于树莓派官方源速度很慢,建议在官网查询国内源 [树莓派软件源](https://www.jianshu.com/p/67b9e6ebf8a0)。更换后执行
```
sudo apt-get update
sudo apt-get upgrade
```
## Paddle-Lite部署
基于Paddle-Lite的部署目前可以支持PaddleX的分类、分割与检测模型。部署的流程包括:PaddleX模型转换与转换后的模型部署
**说明**:PaddleX安装请参考[PaddleX](https://paddlex.readthedocs.io/zh_CN/latest/install.html),Paddle-Lite详细资料请参考[Paddle-Lite](https://paddle-lite.readthedocs.io/zh/latest/index.html)
请确保系统已经安装好上述基本软件,并配置好相应环境,**下面所有示例以工作目录 `/root/projects/`演示**
## Paddle-Lite模型转换
将PaddleX模型转换为Paddle-Lite模型,具体请参考[Paddle-Lite模型转换](./export_nb_model.md)
## Paddle-Lite 预测
### Step1 下载PaddleX预测代码
```
mkdir -p /root/projects
cd /root/projects
git clone https://github.com/PaddlePaddle/PaddleX.git
```
**说明**:其中C++预测代码在PaddleX/deploy/raspberry 目录,该目录不依赖任何PaddleX下其他目录,如果需要在python下预测部署请参考[python预测部署](./python.md)
### Step2:Paddle-Lite预编译库下载
提供了下载的opt工具对应的Paddle-Lite在ArmLinux下面的预编译库:[Paddle-Lite(ArmLinux)预编译库](https://bj.bcebos.com/paddlex/deploy/lite/inference_lite_2.6.1_armlinux.tar.bz2)。建议用户使用预编译库,若需要自行编译,在树莓派上LX终端输入
```
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
sudo ./lite/tools/build.sh --arm_os=armlinux --arm_abi=armv7hf --arm_lang=gcc --build_extra=ON full_publish
```
预编库位置:`./build.lite.armlinux.armv7hf.gcc/inference_lite_lib.armlinux.armv7hf/cxx`
**注意**:预测库版本需要跟opt版本一致,更多Paddle-Lite编译内容请参考[Paddle-Lite编译](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html);更多预编译Paddle-Lite预测库请参考[Paddle-Lite Release Note](https://github.com/PaddlePaddle/Paddle-Lite/releases)
### Step3 软件依赖
提供了依赖软件的预编包或者一键编译,用户不需要单独下载或编译第三方依赖软件。若需要自行编译第三方依赖软件请参考:
- gflags:编译请参考 [编译文档](https://gflags.github.io/gflags/#download)
- glog:编译请参考[编译文档](https://github.com/google/glog)
- opencv: 编译请参考
[编译文档](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html)
### Step4: 编译
编译`cmake`的命令在`scripts/build.sh`中,修改LITE_DIR为Paddle-Lite预测库目录,若自行编译第三方依赖软件请根据Step3中编译软件的实际情况修改主要参数,其主要内容说明如下:
```
# Paddle-Lite预编译库的路径
LITE_DIR=/path/to/Paddle-Lite/inference/lib
# gflags预编译库的路径
GFLAGS_DIR=$(pwd)/deps/gflags
# glog预编译库的路径
GLOG_DIR=$(pwd)/deps/glog
# opencv预编译库的路径
OPENCV_DIR=$(pwd)/deps/opencv/
```
执行`build`脚本:
```shell
sh ./scripts/build.sh
```
### Step5: 预测
编译成功后,分类任务的预测可执行程序为`classifier`,分割任务的预测可执行程序为`segmenter`,检测任务的预测可执行程序为`detector`,其主要命令参数说明如下:
| 参数 | 说明 |
| ---- | ---- |
| --model_dir | 模型转换生成的.nb模型文件路径 |
| --image | 要预测的图片文件路径 |
| --image_list | 按行存储图片路径的.txt文件 |
| --thread_num | 预测的线程数,默认值为1 |
| --cfg_dir | PaddleX model 的.yml配置文件 |
| --save_dir | 可视化结果图片保存地址,仅适用于检测和分割任务,默认值为" ",即不保存可视化结果 |
### 样例
`样例一`
单张图片分类任务
测试图片 `/path/to/test_img.jpeg`
```shell
./build/classifier --model_dir=/path/to/nb_model \
    --image=/path/to/test_img.jpeg --cfg_dir=/path/to/PaddleX_model.yml --thread_num=4
```
`样例二`:
多张图片分割任务
预测多个图片`/path/to/image_list.txt`,image_list.txt内容的格式如下:
```
/path/to/images/test_img1.jpeg
/path/to/images/test_img2.jpeg
...
/path/to/images/test_imgn.jpeg
```
```shell
./build/segmenter --model_dir=/path/to/models/nb_model --image_list=/root/projects/images_list.txt --cfg_dir=/path/to/PadlleX_model.yml --save_dir ./output --thread_num=4
```
## 性能测试
### 测试环境:
硬件:Raspberry Pi 3 Model B
系统:raspbian OS
软件:paddle-lite 2.6.1
### 测试结果
单位ms,num表示paddle-lite下使用的线程数
|模型|lite(num=4)|输入图片大小|
| ----| ---- | ----|
|mobilenet-v2|136.19|224*224|
|resnet-50|1131.42|224*224|
|deeplabv3|2162.03|512*512|
|hrnet|6118.23|512*512|
|yolov3-darknet53|4741.15|320*320|
|yolov3-mobilenet|1424.01|320*320|
|densenet121|1144.92|224*224|
|densenet161|2751.57|224*224|
|densenet201|1847.06|224*224|
|HRNet_W18|1753.06|224*224|
|MobileNetV1|177.63|224*224|
|MobileNetV3_large_ssld|133.99|224*224|
|MobileNetV3_small_ssld|53.99|224*224|
|ResNet101|2290.56|224*224|
|ResNet101_vd|2337.51|224*224|
|ResNet101_vd_ssld|3124.49|224*224|
|ShuffleNetV2|115.97|224*224|
|Xception41|1418.29|224*224|
|Xception65|2094.7|224*224|
从测试结果看建议用户在树莓派上使用MobileNetV1-V3,ShuffleNetV2这类型的小型网络
## NCS2部署
树莓派支持通过OpenVINO在NCS2上跑PaddleX模型预测,目前仅支持PaddleX的分类网络,基于NCS2的方式包含Paddle模型转OpenVINO IR以及部署IR在NCS2上进行预测两个步骤。
- 模型转换请参考:[PaddleX模型转换为OpenVINO IR](./openvino/export_openvino_model.md),raspbian OS上的OpenVINO不支持模型转换,需要先在host侧转换FP16的IR。
- 预测部署请参考[OpenVINO部署](./openvino/linux.md)中VPU在raspbian OS部署的部分
\ No newline at end of file
# Paddle-Lite模型转换
将Paddle模型转换为Paddle-Lite的nb模型,模型转换主要包括PaddleX转inference model和inference model转Paddle-Lite nb模型
### Step1:导出inference模型
PaddleX模型转Paddle-Lite模型之前需要先把PaddleX模型导出为inference格式模型,导出的模型将包括__model__、__params__和model.yml三个文件。具体方法请参考[Inference模型导出](../export_model.md)
### Step2:导出Paddle-Lite模型
Paddle-Lite模型需要通过Paddle-Lite的opt工具转出模型,下载并解压: [模型优化工具opt(2.6.1-linux)](https://bj.bcebos.com/paddlex/deploy/Rasoberry/opt.zip),在Linux系统下运行:
``` bash
./opt --model_file=<model_path> \
--param_file=<param_path> \
--valid_targets=arm \
--optimize_out_type=naive_buffer \
--optimize_out=model_output_name
```
| 参数 | 说明 |
| ---- | ---- |
| --model_file | 导出inference模型中包含的网络结构文件:`__model__`所在的路径|
| --param_file | 导出inference模型中包含的参数文件:`__params__`所在的路径|
| --valid_targets |指定模型可执行的backend,这里请指定为`arm`|
| --optimize_out_type | 输出模型类型,目前支持两种类型:protobuf和naive_buffer,其中naive_buffer是一种更轻量级的序列化/反序列化,这里请指定为`naive_buffer`|
若安装了python版本的Paddle-Lite也可以通过如下方式转换
```
./paddle_lite_opt --model_file=<model_path> \
--param_file=<param_path> \
--valid_targets=arm \
--optimize_out_type=naive_buffer \
--optimize_out=model_output_name
```
详细的使用方法和参数含义请参考: [使用opt转化模型](https://paddle-lite.readthedocs.io/zh/latest/user_guides/opt/opt_bin.html),更多opt预编译版本请参考[Paddle-Lite Release Note](https://github.com/PaddlePaddle/Paddle-Lite/releases)
**注意**:opt版本需要跟预测库版本保持一致,如使用2.6.0的python版预测库,请从上面Release Note中下载2.6.0版本的opt转换模型
\ No newline at end of file
# Python预测部署
本文档说明了如何在树莓派上使用python版本的Paddle-Lite进行PaddleX模型的预测部署。python版本的Paddle-Lite预测库可以通过下面的pip命令安装,用户也可以下载whl文件进行安装:[Paddle-Lite_2.6.0_python](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.6.0/armlinux_python_installer.zip),更多版本请参考[Paddle-Lite Release Note](https://github.com/PaddlePaddle/Paddle-Lite/releases)
```
python -m pip install paddlelite
```
部署前需要先将PaddleX模型转换为Paddle-Lite的nb模型,具体请参考[Paddle-Lite模型转换](./export_nb_model.md)
## 前置条件
* Python 3.6+
* Paddle-Lite_python 2.6.0+
请确保系统已经安装好上述基本软件,**下面所有示例以工作目录 `/root/projects/`演示**
## 预测部署
运行/root/projects/PaddleX/deploy/raspberry/python目录下demo.py文件可以进行预测,其命令参数说明如下:
| 参数 | 说明 |
| ---- | ---- |
| --model_dir | 模型转换生成的.nb模型文件路径 |
| --img | 要预测的图片文件路径 |
| --image_list | 按行存储图片路径的.txt文件 |
| --cfg_dir | PaddleX model 的.yml配置文件 |
| --thread_num | 预测的线程数, 默认值为1 |
| --input_shape | 模型输入中图片输入的大小[N,C,H,W] |
### 样例
`样例一`
测试图片 `/path/to/test_img.jpeg`
```
cd /root/projects/python
python demo.py --model_dir /path/to/openvino_model --img /path/to/test_img.jpeg --cfg_dir /path/to/PadlleX_model.yml --thread_num 4 --input_shape [1,3,224,224]
```
`样例二`:
预测多个图片`/path/to/image_list.txt`,image_list.txt内容的格式如下:
```
/path/to/images/test_img1.jpeg
/path/to/images/test_img2.jpeg
...
/path/to/images/test_imgn.jpeg
```
```
cd /root/projects/python
python demo.py --model_dir /path/to/models/openvino_model --image_list /root/projects/images_list.txt --cfg_dir=/path/to/PadlleX_model.yml --thread_num 4 --input_shape [1,3,224,224]
```
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册