未验证 提交 dd87d289 编写于 作者: W wangguanzhong 提交者: GitHub

fix_inference_in_ttfnet (#1063)

上级 2cc0c157
...@@ -39,6 +39,8 @@ class ImageBlob { ...@@ -39,6 +39,8 @@ class ImageBlob {
std::vector<float> ori_im_size_f_; std::vector<float> ori_im_size_f_;
// Evaluation image width and height // Evaluation image width and height
std::vector<float> eval_im_size_f_; std::vector<float> eval_im_size_f_;
// Scale factor for image size to origin image size
std::vector<float> scale_factor_f_;
}; };
// Abstraction of preprocessing opration class // Abstraction of preprocessing opration class
......
...@@ -140,7 +140,7 @@ void ObjectDetector::Postprocess( ...@@ -140,7 +140,7 @@ void ObjectDetector::Postprocess(
int ymax = (output_data_[5 + j * 6] * rh); int ymax = (output_data_[5 + j * 6] * rh);
int wd = xmax - xmin; int wd = xmax - xmin;
int hd = ymax - ymin; int hd = ymax - ymin;
if (score > threshold_) { if (score > threshold_ && class_id > -1) {
ObjectResult result_item; ObjectResult result_item;
result_item.rect = {xmin, xmax, ymin, ymax}; result_item.rect = {xmin, xmax, ymin, ymax};
result_item.class_id = class_id; result_item.class_id = class_id;
...@@ -172,6 +172,9 @@ void ObjectDetector::Predict(const cv::Mat& im, ...@@ -172,6 +172,9 @@ void ObjectDetector::Predict(const cv::Mat& im,
} else if (tensor_name == "im_shape") { } else if (tensor_name == "im_shape") {
in_tensor->Reshape({1, 3}); in_tensor->Reshape({1, 3});
in_tensor->copy_from_cpu(inputs_.ori_im_size_f_.data()); in_tensor->copy_from_cpu(inputs_.ori_im_size_f_.data());
} else if (tensor_name == "scale_factor") {
in_tensor->Reshape({1, 4});
in_tensor->copy_from_cpu(inputs_.scale_factor_f_.data());
} }
} }
// Run predictor // Run predictor
......
...@@ -78,6 +78,12 @@ void Resize::Run(cv::Mat* im, ImageBlob* data) { ...@@ -78,6 +78,12 @@ void Resize::Run(cv::Mat* im, ImageBlob* data) {
static_cast<float>(im->cols), static_cast<float>(im->cols),
resize_scale.first resize_scale.first
}; };
data->scale_factor_f_ = {
resize_scale.first,
resize_scale.second,
resize_scale.first,
resize_scale.second
};
} }
std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) { std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
......
...@@ -178,6 +178,8 @@ def main(): ...@@ -178,6 +178,8 @@ def main():
for k, v in zip(keys, outs) for k, v in zip(keys, outs)
} }
logger.info('Infer iter {}'.format(iter_id)) logger.info('Infer iter {}'.format(iter_id))
if 'TTFNet' in cfg.architecture:
res['bbox'][1].append([len(res['bbox'][0])])
bbox_results = None bbox_results = None
mask_results = None mask_results = None
...@@ -256,4 +258,4 @@ if __name__ == '__main__': ...@@ -256,4 +258,4 @@ if __name__ == '__main__':
default="vdl_log_dir/image", default="vdl_log_dir/image",
help='VisualDL logging directory for image.') help='VisualDL logging directory for image.')
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册