提交 b99dfb81 编写于 作者: C Channingss

update code

上级 771c6709
...@@ -48,16 +48,17 @@ class Model { ...@@ -48,16 +48,17 @@ class Model {
bool load_config(const std::string& model_dir); bool load_config(const std::string& model_dir);
bool preprocess(cv::Mat* input_im, ImageBlob* blob); bool preprocess(cv::Mat* input_im);
bool predict(cv::Mat* im, ClsResult* result); bool predict(const cv::Mat& im, ClsResult* result);
std::string type; std::string type;
std::string name; std::string name;
std::vector<std::string> labels; std::vector<std::string> labels;
Transforms transforms_; Transforms transforms_;
Blob::Ptr inputs_; Blob::Ptr inputs_;
Blob::Ptr output_ Blob::Ptr output_;
CNNNetwork network_; CNNNetwork network_;
ExecutableNetwork executable_network_;
}; };
} // namespce of PaddleX } // namespce of PaddleX
...@@ -35,7 +35,7 @@ namespace PaddleX { ...@@ -35,7 +35,7 @@ namespace PaddleX {
class Transform { class Transform {
public: public:
virtual void Init(const YAML::Node& item) = 0; virtual void Init(const YAML::Node& item) = 0;
virtual bool Run(cv::Mat* im, ImageBlob* data) = 0; virtual bool Run(cv::Mat* im) = 0;
}; };
class Normalize : public Transform { class Normalize : public Transform {
...@@ -45,39 +45,32 @@ class Normalize : public Transform { ...@@ -45,39 +45,32 @@ class Normalize : public Transform {
std_ = item["std"].as<std::vector<float>>(); std_ = item["std"].as<std::vector<float>>();
} }
virtual bool Run(cv::Mat* im, ImageBlob* data); virtual bool Run(cv::Mat* im);
private: private:
std::vector<float> mean_; std::vector<float> mean_;
std::vector<float> std_; std::vector<float> std_;
}; };
class ResizeByShort : public Transform {
class Resize : public Transform {
public: public:
virtual void Init(const YAML::Node& item) { virtual void Init(const YAML::Node& item) {
if (item["target_size"].IsScalar()) { short_size_ = item["short_size"].as<int>();
height_ = item["target_size"].as<int>(); if (item["max_size"].IsDefined()) {
width_ = item["target_size"].as<int>(); max_size_ = item["max_size"].as<int>();
interp_ = item["interp"].as<std::string>(); } else {
} else if (item["target_size"].IsSequence()) { max_size_ = -1;
std::vector<int> target_size = item["target_size"].as<std::vector<int>>();
width_ = target_size[0];
height_ = target_size[1];
}
if (height_ <= 0 || width_ <= 0) {
std::cerr << "[Resize] target_size should greater than 0" << std::endl;
exit(-1);
}
} }
virtual bool Run(cv::Mat* im, ImageBlob* data); };
virtual bool Run(cv::Mat* im);
private: private:
int height_; float GenerateScale(const cv::Mat& im);
int width_; int short_size_;
std::string interp_; int max_size_;
}; };
class CenterCrop : public Transform { class CenterCrop : public Transform {
public: public:
virtual void Init(const YAML::Node& item) { virtual void Init(const YAML::Node& item) {
...@@ -90,7 +83,7 @@ class CenterCrop : public Transform { ...@@ -90,7 +83,7 @@ class CenterCrop : public Transform {
height_ = crop_size[1]; height_ = crop_size[1];
} }
} }
virtual bool Run(cv::Mat* im, ImageBlob* data); virtual bool Run(cv::Mat* im);
private: private:
int height_; int height_;
...@@ -101,7 +94,7 @@ class Transforms { ...@@ -101,7 +94,7 @@ class Transforms {
public: public:
void Init(const YAML::Node& node, bool to_rgb = true); void Init(const YAML::Node& node, bool to_rgb = true);
std::shared_ptr<Transform> CreateTransform(const std::string& name); std::shared_ptr<Transform> CreateTransform(const std::string& name);
bool Run(cv::Mat* im, Blob::ptr data); bool Run(cv::Mat* im, Blob::Ptr blob);
private: private:
std::vector<std::shared_ptr<Transform>> transforms_; std::vector<std::shared_ptr<Transform>> transforms_;
......
...@@ -55,7 +55,7 @@ int main(int argc, char** argv) { ...@@ -55,7 +55,7 @@ int main(int argc, char** argv) {
while (getline(inf, image_path)) { while (getline(inf, image_path)) {
PaddleX::ClsResult result; PaddleX::ClsResult result;
cv::Mat im = cv::imread(image_path, 1); cv::Mat im = cv::imread(image_path, 1);
model.predict(&im, &result); model.predict(im, &result);
std::cout << "Predict label: " << result.category std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id << ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl; << ", score: " << result.score << std::endl;
...@@ -63,7 +63,7 @@ int main(int argc, char** argv) { ...@@ -63,7 +63,7 @@ int main(int argc, char** argv) {
} else { } else {
PaddleX::ClsResult result; PaddleX::ClsResult result;
cv::Mat im = cv::imread(FLAGS_image, 1); cv::Mat im = cv::imread(FLAGS_image, 1);
model.predict(&im, &result); model.predict(im, &result);
std::cout << "Predict label: " << result.category std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id << ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl; << ", score: " << result.score << std::endl;
......
...@@ -29,7 +29,7 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -29,7 +29,7 @@ void Model::create_predictor(const std::string& model_dir,
input_info->getPreProcess().setResizeAlgorithm(RESIZE_BILINEAR); input_info->getPreProcess().setResizeAlgorithm(RESIZE_BILINEAR);
input_info->setLayout(Layout::NCHW); input_info->setLayout(Layout::NCHW);
input_info->setPrecision(Precision::FP32); input_info->setPrecision(Precision::FP32);
executable_network_ = ie.LoadNetwork(network_, device);
load_config(cfg_dir); load_config(cfg_dir);
} }
...@@ -56,15 +56,14 @@ bool Model::load_config(const std::string& cfg_dir) { ...@@ -56,15 +56,14 @@ bool Model::load_config(const std::string& cfg_dir) {
return true; return true;
} }
bool Model::preprocess(cv::Mat* input_im, ImageBlob* blob) { bool Model::preprocess(cv::Mat* input_im) {
if (!transforms_.Run(input_im, &inputs_)) { if (!transforms_.Run(input_im, inputs_)) {
return false; return false;
} }
return true; return true;
} }
bool Model::predict(const cv::Mat& im, ClsResult* result) { bool Model::predict(const cv::Mat& im, ClsResult* result) {
inputs_.clear();
if (type == "detector") { if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to " std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!" "function predict()!"
...@@ -77,14 +76,12 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) { ...@@ -77,14 +76,12 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
return false; return false;
} }
// 处理输入图像 // 处理输入图像
InferRequest infer_request = executable_network_.CreateInferRequest();
executable_network = ie.LoadNetwork(network_, device);
InferRequest infer_request = executable_network.CreateInferRequest();
std::string input_name = network_.getInputsInfo().begin()->first; std::string input_name = network_.getInputsInfo().begin()->first;
input_ = infer_request.GetBlob(input_name); inputs_ = infer_request.GetBlob(input_name);
auto im_clone = im.clone(); auto im_clone = im.clone();
if (!preprocess(&im_clone, inputs_)) { if (!preprocess(&im_clone)) {
std::cerr << "Preprocess failed!" << std::endl; std::cerr << "Preprocess failed!" << std::endl;
return false; return false;
} }
...@@ -93,12 +90,12 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) { ...@@ -93,12 +90,12 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
std::string output_name = network_.getOutputsInfo().begin()->first; std::string output_name = network_.getOutputsInfo().begin()->first;
output_ = infer_request.GetBlob(output_name); output_ = infer_request.GetBlob(output_name);
MemoryBlob::CPtr moutput = as<MemoryBlob>(output); MemoryBlob::CPtr moutput = as<MemoryBlob>(output_);
auto moutputHolder = moutput->rmap(); auto moutputHolder = moutput->rmap();
float* outputs_data = moutputHolder.as<float *>(); float* outputs_data = moutputHolder.as<float *>();
// 对模型输出结果进行后处理 // 对模型输出结果进行后处理
auto ptr = std::max_element(outputs_data, outputs_data+sizeof(outputs_)); auto ptr = std::max_element(outputs_data, outputs_data+sizeof(outputs_data));
result->category_id = std::distance(outputs_data, ptr); result->category_id = std::distance(outputs_data, ptr);
result->score = *ptr; result->score = *ptr;
result->category = labels[result->category_id]; result->category = labels[result->category_id];
......
...@@ -26,7 +26,7 @@ std::map<std::string, int> interpolations = {{"LINEAR", cv::INTER_LINEAR}, ...@@ -26,7 +26,7 @@ std::map<std::string, int> interpolations = {{"LINEAR", cv::INTER_LINEAR},
{"CUBIC", cv::INTER_CUBIC}, {"CUBIC", cv::INTER_CUBIC},
{"LANCZOS4", cv::INTER_LANCZOS4}}; {"LANCZOS4", cv::INTER_LANCZOS4}};
bool Normalize::Run(cv::Mat* im, ImageBlob* data) { bool Normalize::Run(cv::Mat* im){
for (int h = 0; h < im->rows; h++) { for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) { for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] = im->at<cv::Vec3f>(h, w)[0] =
...@@ -40,7 +40,7 @@ bool Normalize::Run(cv::Mat* im, ImageBlob* data) { ...@@ -40,7 +40,7 @@ bool Normalize::Run(cv::Mat* im, ImageBlob* data) {
return true; return true;
} }
bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) { bool CenterCrop::Run(cv::Mat* im) {
int height = static_cast<int>(im->rows); int height = static_cast<int>(im->rows);
int width = static_cast<int>(im->cols); int width = static_cast<int>(im->cols);
if (height < height_ || width < width_) { if (height < height_ || width < width_) {
...@@ -51,30 +51,30 @@ bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) { ...@@ -51,30 +51,30 @@ bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
int offset_y = static_cast<int>((height - height_) / 2); int offset_y = static_cast<int>((height - height_) / 2);
cv::Rect crop_roi(offset_x, offset_y, width_, height_); cv::Rect crop_roi(offset_x, offset_y, width_, height_);
*im = (*im)(crop_roi); *im = (*im)(crop_roi);
data->new_im_size_[0] = im->rows;
data->new_im_size_[1] = im->cols;
return true; return true;
} }
bool Resize::Run(cv::Mat* im, ImageBlob* data) { float ResizeByShort::GenerateScale(const cv::Mat& im) {
if (width_ <= 0 || height_ <= 0) { int origin_w = im.cols;
std::cerr << "[Resize] width and height should be greater than 0" int origin_h = im.rows;
<< std::endl; int im_size_max = std::max(origin_w, origin_h);
return false; int im_size_min = std::min(origin_w, origin_h);
float scale =
static_cast<float>(short_size_) / static_cast<float>(im_size_min);
if (max_size_ > 0) {
if (round(scale * im_size_max) > max_size_) {
scale = static_cast<float>(max_size_) / static_cast<float>(im_size_max);
} }
if (interpolations.count(interp_) <= 0) {
std::cerr << "[Resize] Invalid interpolation method: '" << interp_ << "'"
<< std::endl;
return false;
} }
data->im_size_before_resize_.push_back({im->rows, im->cols}); return scale;
data->reshape_order_.push_back("resize"); }
cv::resize( bool ResizeByShort::Run(cv::Mat* im) {
*im, *im, cv::Size(width_, height_), 0, 0, interpolations[interp_]); float scale = GenerateScale(*im);
data->new_im_size_[0] = im->rows; int width = static_cast<int>(scale * im->cols);
data->new_im_size_[1] = im->cols; int height = static_cast<int>(scale * im->rows);
cv::resize(*im, *im, cv::Size(width, height), 0, 0, cv::INTER_LINEAR);
return true; return true;
} }
...@@ -96,8 +96,8 @@ std::shared_ptr<Transform> Transforms::CreateTransform( ...@@ -96,8 +96,8 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
return std::make_shared<Normalize>(); return std::make_shared<Normalize>();
} else if (transform_name == "CenterCrop") { } else if (transform_name == "CenterCrop") {
return std::make_shared<CenterCrop>(); return std::make_shared<CenterCrop>();
} else if (transform_name == "Resize") { } else if (transform_name == "ResizeByShort") {
return std::make_shared<Resize>(); return std::make_shared<ResizeByShort>();
} else { } else {
std::cerr << "There's unexpected transform(name='" << transform_name std::cerr << "There's unexpected transform(name='" << transform_name
<< "')." << std::endl; << "')." << std::endl;
...@@ -105,7 +105,7 @@ std::shared_ptr<Transform> Transforms::CreateTransform( ...@@ -105,7 +105,7 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
} }
} }
bool Transforms::Run(cv::Mat* im, Blob::ptr data) { bool Transforms::Run(cv::Mat* im, Blob::Ptr blob) {
// 按照transforms中预处理算子顺序处理图像 // 按照transforms中预处理算子顺序处理图像
if (to_rgb_) { if (to_rgb_) {
cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB); cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
...@@ -113,7 +113,7 @@ bool Transforms::Run(cv::Mat* im, Blob::ptr data) { ...@@ -113,7 +113,7 @@ bool Transforms::Run(cv::Mat* im, Blob::ptr data) {
(*im).convertTo(*im, CV_32FC3); (*im).convertTo(*im, CV_32FC3);
for (int i = 0; i < transforms_.size(); ++i) { for (int i = 0; i < transforms_.size(); ++i) {
if (!transforms_[i]->Run(im, data)) { if (!transforms_[i]->Run(im)) {
std::cerr << "Apply transforms to image failed!" << std::endl; std::cerr << "Apply transforms to image failed!" << std::endl;
return false; return false;
} }
...@@ -121,7 +121,7 @@ bool Transforms::Run(cv::Mat* im, Blob::ptr data) { ...@@ -121,7 +121,7 @@ bool Transforms::Run(cv::Mat* im, Blob::ptr data) {
// 将图像由NHWC转为NCHW格式 // 将图像由NHWC转为NCHW格式
// 同时转为连续的内存块存储到Blob // 同时转为连续的内存块存储到Blob
SizeVector blobSize = data_->getTensorDesc().getDims(); SizeVector blobSize = blob->getTensorDesc().getDims();
const size_t width = blobSize[3]; const size_t width = blobSize[3];
const size_t height = blobSize[2]; const size_t height = blobSize[2];
const size_t channels = blobSize[1]; const size_t channels = blobSize[1];
...@@ -132,7 +132,7 @@ bool Transforms::Run(cv::Mat* im, Blob::ptr data) { ...@@ -132,7 +132,7 @@ bool Transforms::Run(cv::Mat* im, Blob::ptr data) {
for (size_t h = 0; h < height; h++) { for (size_t h = 0; h < height; h++) {
for (size_t w = 0; w < width; w++) { for (size_t w = 0; w < width; w++) {
blob_data[c * width * height + h * width + w] = blob_data[c * width * height + h * width + w] =
im.at<cv::Vec3f>(h, w)[c]; im->at<cv::Vec3f>(h, w)[c];
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册