提交 b99dfb81 编写于 作者: C Channingss

update code

上级 771c6709
......@@ -48,16 +48,17 @@ class Model {
bool load_config(const std::string& model_dir);
bool preprocess(cv::Mat* input_im, ImageBlob* blob);
bool preprocess(cv::Mat* input_im);
bool predict(cv::Mat* im, ClsResult* result);
bool predict(const cv::Mat& im, ClsResult* result);
std::string type;
std::string name;
std::vector<std::string> labels;
Transforms transforms_;
Blob::Ptr inputs_;
Blob::Ptr output_
Blob::Ptr output_;
CNNNetwork network_;
ExecutableNetwork executable_network_;
};
} // namespce of PaddleX
......@@ -35,7 +35,7 @@ namespace PaddleX {
class Transform {
public:
virtual void Init(const YAML::Node& item) = 0;
virtual bool Run(cv::Mat* im, ImageBlob* data) = 0;
virtual bool Run(cv::Mat* im) = 0;
};
class Normalize : public Transform {
......@@ -45,39 +45,32 @@ class Normalize : public Transform {
std_ = item["std"].as<std::vector<float>>();
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
virtual bool Run(cv::Mat* im);
private:
std::vector<float> mean_;
std::vector<float> std_;
};
class Resize : public Transform {
class ResizeByShort : public Transform {
public:
virtual void Init(const YAML::Node& item) {
if (item["target_size"].IsScalar()) {
height_ = item["target_size"].as<int>();
width_ = item["target_size"].as<int>();
interp_ = item["interp"].as<std::string>();
} else if (item["target_size"].IsSequence()) {
std::vector<int> target_size = item["target_size"].as<std::vector<int>>();
width_ = target_size[0];
height_ = target_size[1];
}
if (height_ <= 0 || width_ <= 0) {
std::cerr << "[Resize] target_size should greater than 0" << std::endl;
exit(-1);
short_size_ = item["short_size"].as<int>();
if (item["max_size"].IsDefined()) {
max_size_ = item["max_size"].as<int>();
} else {
max_size_ = -1;
}
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
};
virtual bool Run(cv::Mat* im);
private:
int height_;
int width_;
std::string interp_;
float GenerateScale(const cv::Mat& im);
int short_size_;
int max_size_;
};
class CenterCrop : public Transform {
public:
virtual void Init(const YAML::Node& item) {
......@@ -90,7 +83,7 @@ class CenterCrop : public Transform {
height_ = crop_size[1];
}
}
virtual bool Run(cv::Mat* im, ImageBlob* data);
virtual bool Run(cv::Mat* im);
private:
int height_;
......@@ -101,7 +94,7 @@ class Transforms {
public:
void Init(const YAML::Node& node, bool to_rgb = true);
std::shared_ptr<Transform> CreateTransform(const std::string& name);
bool Run(cv::Mat* im, Blob::ptr data);
bool Run(cv::Mat* im, Blob::Ptr blob);
private:
std::vector<std::shared_ptr<Transform>> transforms_;
......
......@@ -55,7 +55,7 @@ int main(int argc, char** argv) {
while (getline(inf, image_path)) {
PaddleX::ClsResult result;
cv::Mat im = cv::imread(image_path, 1);
model.predict(&im, &result);
model.predict(im, &result);
std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl;
......@@ -63,7 +63,7 @@ int main(int argc, char** argv) {
} else {
PaddleX::ClsResult result;
cv::Mat im = cv::imread(FLAGS_image, 1);
model.predict(&im, &result);
model.predict(im, &result);
std::cout << "Predict label: " << result.category
<< ", label_id:" << result.category_id
<< ", score: " << result.score << std::endl;
......
......@@ -29,7 +29,7 @@ void Model::create_predictor(const std::string& model_dir,
input_info->getPreProcess().setResizeAlgorithm(RESIZE_BILINEAR);
input_info->setLayout(Layout::NCHW);
input_info->setPrecision(Precision::FP32);
executable_network_ = ie.LoadNetwork(network_, device);
load_config(cfg_dir);
}
......@@ -56,15 +56,14 @@ bool Model::load_config(const std::string& cfg_dir) {
return true;
}
bool Model::preprocess(cv::Mat* input_im, ImageBlob* blob) {
if (!transforms_.Run(input_im, &inputs_)) {
bool Model::preprocess(cv::Mat* input_im) {
if (!transforms_.Run(input_im, inputs_)) {
return false;
}
return true;
}
bool Model::predict(const cv::Mat& im, ClsResult* result) {
inputs_.clear();
if (type == "detector") {
std::cerr << "Loading model is a 'detector', DetResult should be passed to "
"function predict()!"
......@@ -77,14 +76,12 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
return false;
}
// 处理输入图像
executable_network = ie.LoadNetwork(network_, device);
InferRequest infer_request = executable_network.CreateInferRequest();
InferRequest infer_request = executable_network_.CreateInferRequest();
std::string input_name = network_.getInputsInfo().begin()->first;
input_ = infer_request.GetBlob(input_name);
inputs_ = infer_request.GetBlob(input_name);
auto im_clone = im.clone();
if (!preprocess(&im_clone, inputs_)) {
if (!preprocess(&im_clone)) {
std::cerr << "Preprocess failed!" << std::endl;
return false;
}
......@@ -93,12 +90,12 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
std::string output_name = network_.getOutputsInfo().begin()->first;
output_ = infer_request.GetBlob(output_name);
MemoryBlob::CPtr moutput = as<MemoryBlob>(output);
MemoryBlob::CPtr moutput = as<MemoryBlob>(output_);
auto moutputHolder = moutput->rmap();
float* outputs_data = moutputHolder.as<float *>();
// 对模型输出结果进行后处理
auto ptr = std::max_element(outputs_data, outputs_data+sizeof(outputs_));
auto ptr = std::max_element(outputs_data, outputs_data+sizeof(outputs_data));
result->category_id = std::distance(outputs_data, ptr);
result->score = *ptr;
result->category = labels[result->category_id];
......
......@@ -26,7 +26,7 @@ std::map<std::string, int> interpolations = {{"LINEAR", cv::INTER_LINEAR},
{"CUBIC", cv::INTER_CUBIC},
{"LANCZOS4", cv::INTER_LANCZOS4}};
bool Normalize::Run(cv::Mat* im, ImageBlob* data) {
bool Normalize::Run(cv::Mat* im){
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
......@@ -40,7 +40,7 @@ bool Normalize::Run(cv::Mat* im, ImageBlob* data) {
return true;
}
bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
bool CenterCrop::Run(cv::Mat* im) {
int height = static_cast<int>(im->rows);
int width = static_cast<int>(im->cols);
if (height < height_ || width < width_) {
......@@ -51,30 +51,30 @@ bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
int offset_y = static_cast<int>((height - height_) / 2);
cv::Rect crop_roi(offset_x, offset_y, width_, height_);
*im = (*im)(crop_roi);
data->new_im_size_[0] = im->rows;
data->new_im_size_[1] = im->cols;
return true;
}
bool Resize::Run(cv::Mat* im, ImageBlob* data) {
if (width_ <= 0 || height_ <= 0) {
std::cerr << "[Resize] width and height should be greater than 0"
<< std::endl;
return false;
}
if (interpolations.count(interp_) <= 0) {
std::cerr << "[Resize] Invalid interpolation method: '" << interp_ << "'"
<< std::endl;
return false;
float ResizeByShort::GenerateScale(const cv::Mat& im) {
int origin_w = im.cols;
int origin_h = im.rows;
int im_size_max = std::max(origin_w, origin_h);
int im_size_min = std::min(origin_w, origin_h);
float scale =
static_cast<float>(short_size_) / static_cast<float>(im_size_min);
if (max_size_ > 0) {
if (round(scale * im_size_max) > max_size_) {
scale = static_cast<float>(max_size_) / static_cast<float>(im_size_max);
}
}
data->im_size_before_resize_.push_back({im->rows, im->cols});
data->reshape_order_.push_back("resize");
return scale;
}
cv::resize(
*im, *im, cv::Size(width_, height_), 0, 0, interpolations[interp_]);
data->new_im_size_[0] = im->rows;
data->new_im_size_[1] = im->cols;
bool ResizeByShort::Run(cv::Mat* im) {
float scale = GenerateScale(*im);
int width = static_cast<int>(scale * im->cols);
int height = static_cast<int>(scale * im->rows);
cv::resize(*im, *im, cv::Size(width, height), 0, 0, cv::INTER_LINEAR);
return true;
}
......@@ -96,8 +96,8 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
return std::make_shared<Normalize>();
} else if (transform_name == "CenterCrop") {
return std::make_shared<CenterCrop>();
} else if (transform_name == "Resize") {
return std::make_shared<Resize>();
} else if (transform_name == "ResizeByShort") {
return std::make_shared<ResizeByShort>();
} else {
std::cerr << "There's unexpected transform(name='" << transform_name
<< "')." << std::endl;
......@@ -105,7 +105,7 @@ std::shared_ptr<Transform> Transforms::CreateTransform(
}
}
bool Transforms::Run(cv::Mat* im, Blob::ptr data) {
bool Transforms::Run(cv::Mat* im, Blob::Ptr blob) {
// 按照transforms中预处理算子顺序处理图像
if (to_rgb_) {
cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
......@@ -113,7 +113,7 @@ bool Transforms::Run(cv::Mat* im, Blob::ptr data) {
(*im).convertTo(*im, CV_32FC3);
for (int i = 0; i < transforms_.size(); ++i) {
if (!transforms_[i]->Run(im, data)) {
if (!transforms_[i]->Run(im)) {
std::cerr << "Apply transforms to image failed!" << std::endl;
return false;
}
......@@ -121,7 +121,7 @@ bool Transforms::Run(cv::Mat* im, Blob::ptr data) {
// 将图像由NHWC转为NCHW格式
// 同时转为连续的内存块存储到Blob
SizeVector blobSize = data_->getTensorDesc().getDims();
SizeVector blobSize = blob->getTensorDesc().getDims();
const size_t width = blobSize[3];
const size_t height = blobSize[2];
const size_t channels = blobSize[1];
......@@ -132,7 +132,7 @@ bool Transforms::Run(cv::Mat* im, Blob::ptr data) {
for (size_t h = 0; h < height; h++) {
for (size_t w = 0; w < width; w++) {
blob_data[c * width * height + h * width + w] =
im.at<cv::Vec3f>(h, w)[c];
im->at<cv::Vec3f>(h, w)[c];
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册