未验证 提交 020e1cd6 编写于 作者: W wangguanzhong 提交者: GitHub

support mask rcnn bs>2 (#5470)

* support mask rcnn bs>2

* fix resize op in c++ deploy

* fix preprocess in mask rcnn

* fix cascade mask rcnn when box is empty

* update merge batch
上级 28f79157
...@@ -120,6 +120,10 @@ class ConfigPaser { ...@@ -120,6 +120,10 @@ class ConfigPaser {
} }
} }
if (config["mask"].IsDefined()) {
mask_ = config["mask"].as<bool>();
}
return true; return true;
} }
std::string mode_; std::string mode_;
...@@ -132,6 +136,7 @@ class ConfigPaser { ...@@ -132,6 +136,7 @@ class ConfigPaser {
std::vector<int> fpn_stride_; std::vector<int> fpn_stride_;
bool use_dynamic_shape_; bool use_dynamic_shape_;
float conf_thresh_; float conf_thresh_;
bool mask_ = false;
}; };
} // namespace PaddleDetection } // namespace PaddleDetection
...@@ -114,6 +114,7 @@ class ObjectDetector { ...@@ -114,6 +114,7 @@ class ObjectDetector {
std::vector<PaddleDetection::ObjectResult>* result, std::vector<PaddleDetection::ObjectResult>* result,
std::vector<int> bbox_num, std::vector<int> bbox_num,
std::vector<float> output_data_, std::vector<float> output_data_,
std::vector<int> output_mask_data_,
bool is_rbox); bool is_rbox);
std::shared_ptr<Predictor> predictor_; std::shared_ptr<Predictor> predictor_;
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
#pragma once #pragma once
#include <string> #include <algorithm>
#include <vector>
#include <memory>
#include <utility>
#include <ctime> #include <ctime>
#include <memory>
#include <numeric> #include <numeric>
#include <algorithm> #include <string>
#include <utility>
#include <vector>
namespace PaddleDetection { namespace PaddleDetection {
...@@ -32,6 +32,8 @@ struct ObjectResult { ...@@ -32,6 +32,8 @@ struct ObjectResult {
int class_id; int class_id;
// Confidence of detected object // Confidence of detected object
float confidence; float confidence;
// Mask of detected object
std::vector<int> mask;
}; };
void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold); void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold);
......
...@@ -41,12 +41,17 @@ void ObjectDetector::LoadModel(const std::string &model_dir, ...@@ -41,12 +41,17 @@ void ObjectDetector::LoadModel(const std::string &model_dir,
} else if (run_mode == "trt_int8") { } else if (run_mode == "trt_int8") {
precision = paddle_infer::Config::Precision::kInt8; precision = paddle_infer::Config::Precision::kInt8;
} else { } else {
printf("run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or " printf(
"run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or "
"'trt_int8'"); "'trt_int8'");
} }
// set tensorrt // set tensorrt
config.EnableTensorRtEngine(1 << 30, batch_size, this->min_subgraph_size_, config.EnableTensorRtEngine(1 << 30,
precision, false, this->trt_calib_mode_); batch_size,
this->min_subgraph_size_,
precision,
false,
this->trt_calib_mode_);
// set use dynamic shape // set use dynamic shape
if (this->use_dynamic_shape_) { if (this->use_dynamic_shape_) {
...@@ -64,8 +69,8 @@ void ObjectDetector::LoadModel(const std::string &model_dir, ...@@ -64,8 +69,8 @@ void ObjectDetector::LoadModel(const std::string &model_dir,
const std::map<std::string, std::vector<int>> map_opt_input_shape = { const std::map<std::string, std::vector<int>> map_opt_input_shape = {
{"image", opt_input_shape}}; {"image", opt_input_shape}};
config.SetTRTDynamicShapeInfo(map_min_input_shape, map_max_input_shape, config.SetTRTDynamicShapeInfo(
map_opt_input_shape); map_min_input_shape, map_max_input_shape, map_opt_input_shape);
std::cout << "TensorRT dynamic shape enabled" << std::endl; std::cout << "TensorRT dynamic shape enabled" << std::endl;
} }
} }
...@@ -90,12 +95,15 @@ void ObjectDetector::LoadModel(const std::string &model_dir, ...@@ -90,12 +95,15 @@ void ObjectDetector::LoadModel(const std::string &model_dir,
} }
// Visualization MaskDetector results // Visualization MaskDetector results
cv::Mat cv::Mat VisualizeResult(
VisualizeResult(const cv::Mat &img, const cv::Mat &img,
const std::vector<PaddleDetection::ObjectResult> &results, const std::vector<PaddleDetection::ObjectResult> &results,
const std::vector<std::string> &lables, const std::vector<std::string> &lables,
const std::vector<int> &colormap, const bool is_rbox = false) { const std::vector<int> &colormap,
const bool is_rbox = false) {
cv::Mat vis_img = img.clone(); cv::Mat vis_img = img.clone();
int img_h = vis_img.rows;
int img_w = vis_img.cols;
for (int i = 0; i < results.size(); ++i) { for (int i = 0; i < results.size(); ++i) {
// Configure color and text size // Configure color and text size
std::ostringstream oss; std::ostringstream oss;
...@@ -129,19 +137,52 @@ VisualizeResult(const cv::Mat &img, ...@@ -129,19 +137,52 @@ VisualizeResult(const cv::Mat &img,
cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[1], w, h); cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[1], w, h);
// Draw roi object, text, and background // Draw roi object, text, and background
cv::rectangle(vis_img, roi, roi_color, 2); cv::rectangle(vis_img, roi, roi_color, 2);
// Draw mask
std::vector<int> mask_v = results[i].mask;
if (mask_v.size() > 0) {
cv::Mat mask = cv::Mat(img_h, img_w, CV_32S);
std::memcpy(mask.data, mask_v.data(), mask_v.size() * sizeof(int));
cv::Mat colored_img = vis_img.clone();
std::vector<cv::Mat> contours;
cv::Mat hierarchy;
mask.convertTo(mask, CV_8U);
cv::findContours(
mask, contours, hierarchy, cv::RETR_CCOMP, cv::CHAIN_APPROX_SIMPLE);
cv::drawContours(colored_img,
contours,
-1,
roi_color,
-1,
cv::LINE_8,
hierarchy,
100);
cv::Mat debug_roi = vis_img;
colored_img = 0.4 * colored_img + 0.6 * vis_img;
colored_img.copyTo(vis_img, mask);
}
} }
origin.x = results[i].rect[0]; origin.x = results[i].rect[0];
origin.y = results[i].rect[1]; origin.y = results[i].rect[1];
// Configure text background // Configure text background
cv::Rect text_back = cv::Rect text_back = cv::Rect(results[i].rect[0],
cv::Rect(results[i].rect[0], results[i].rect[1] - text_size.height, results[i].rect[1] - text_size.height,
text_size.width, text_size.height); text_size.width,
text_size.height);
// Draw text, and background // Draw text, and background
cv::rectangle(vis_img, text_back, roi_color, -1); cv::rectangle(vis_img, text_back, roi_color, -1);
cv::putText(vis_img, text, origin, font_face, font_scale, cv::putText(vis_img,
cv::Scalar(255, 255, 255), thickness); text,
origin,
font_face,
font_scale,
cv::Scalar(255, 255, 255),
thickness);
} }
return vis_img; return vis_img;
} }
...@@ -156,10 +197,18 @@ void ObjectDetector::Preprocess(const cv::Mat &ori_im) { ...@@ -156,10 +197,18 @@ void ObjectDetector::Preprocess(const cv::Mat &ori_im) {
void ObjectDetector::Postprocess( void ObjectDetector::Postprocess(
const std::vector<cv::Mat> mats, const std::vector<cv::Mat> mats,
std::vector<PaddleDetection::ObjectResult> *result, std::vector<PaddleDetection::ObjectResult> *result,
std::vector<int> bbox_num, std::vector<float> output_data_, std::vector<int> bbox_num,
std::vector<float> output_data_,
std::vector<int> output_mask_data_,
bool is_rbox = false) { bool is_rbox = false) {
result->clear(); result->clear();
int start_idx = 0; int start_idx = 0;
int total_num = std::accumulate(bbox_num.begin(), bbox_num.end(), 0);
int out_mask_dim = -1;
if (config_.mask_) {
out_mask_dim = output_mask_data_.size() / total_num;
}
for (int im_id = 0; im_id < mats.size(); im_id++) { for (int im_id = 0; im_id < mats.size(); im_id++) {
cv::Mat raw_mat = mats[im_id]; cv::Mat raw_mat = mats[im_id];
int rh = 1; int rh = 1;
...@@ -204,6 +253,17 @@ void ObjectDetector::Postprocess( ...@@ -204,6 +253,17 @@ void ObjectDetector::Postprocess(
result_item.rect = {xmin, ymin, xmax, ymax}; result_item.rect = {xmin, ymin, xmax, ymax};
result_item.class_id = class_id; result_item.class_id = class_id;
result_item.confidence = score; result_item.confidence = score;
if (config_.mask_) {
std::vector<int> mask;
for (int k = 0; k < out_mask_dim; ++k) {
if (output_mask_data_[k + j * out_mask_dim] > -1) {
mask.push_back(output_mask_data_[k + j * out_mask_dim]);
}
}
result_item.mask = mask;
}
result->push_back(result_item); result->push_back(result_item);
} }
} }
...@@ -212,7 +272,8 @@ void ObjectDetector::Postprocess( ...@@ -212,7 +272,8 @@ void ObjectDetector::Postprocess(
} }
void ObjectDetector::Predict(const std::vector<cv::Mat> imgs, void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
const double threshold, const int warmup, const double threshold,
const int warmup,
const int repeats, const int repeats,
std::vector<PaddleDetection::ObjectResult> *result, std::vector<PaddleDetection::ObjectResult> *result,
std::vector<int> *bbox_num, std::vector<int> *bbox_num,
...@@ -226,6 +287,7 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs, ...@@ -226,6 +287,7 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
std::vector<float> scale_factor_all(batch_size * 2); std::vector<float> scale_factor_all(batch_size * 2);
std::vector<const float *> output_data_list_; std::vector<const float *> output_data_list_;
std::vector<int> out_bbox_num_data_; std::vector<int> out_bbox_num_data_;
std::vector<int> out_mask_data_;
// in_net img for each batch // in_net img for each batch
std::vector<cv::Mat> in_net_img_all(batch_size); std::vector<cv::Mat> in_net_img_all(batch_size);
...@@ -240,8 +302,8 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs, ...@@ -240,8 +302,8 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0]; scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0];
scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1]; scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1];
in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(), in_data_all.insert(
inputs_.im_data_.end()); in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
// collect in_net img // collect in_net img
in_net_img_all[bs_idx] = inputs_.in_net_im_; in_net_img_all[bs_idx] = inputs_.in_net_im_;
...@@ -262,8 +324,8 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs, ...@@ -262,8 +324,8 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
pad_data.resize(rc * rh * rw); pad_data.resize(rc * rh * rw);
float *base = pad_data.data(); float *base = pad_data.data();
for (int i = 0; i < rc; ++i) { for (int i = 0; i < rc; ++i) {
cv::extractChannel(pad_img, cv::extractChannel(
cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i); pad_img, cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
} }
in_data_all.insert(in_data_all.end(), pad_data.begin(), pad_data.end()); in_data_all.insert(in_data_all.end(), pad_data.begin(), pad_data.end());
} }
...@@ -304,9 +366,12 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs, ...@@ -304,9 +366,12 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
for (int j = 0; j < output_names.size(); j++) { for (int j = 0; j < output_names.size(); j++) {
auto output_tensor = predictor_->GetOutputHandle(output_names[j]); auto output_tensor = predictor_->GetOutputHandle(output_names[j]);
std::vector<int> output_shape = output_tensor->shape(); std::vector<int> output_shape = output_tensor->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, int out_num = std::accumulate(
std::multiplies<int>()); output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
if (output_tensor->type() == paddle_infer::DataType::INT32) { if (config_.mask_ && (j == 2)) {
out_mask_data_.resize(out_num);
output_tensor->CopyToCpu(out_mask_data_.data());
} else if (output_tensor->type() == paddle_infer::DataType::INT32) {
out_bbox_num_data_.resize(out_num); out_bbox_num_data_.resize(out_num);
output_tensor->CopyToCpu(out_bbox_num_data_.data()); output_tensor->CopyToCpu(out_bbox_num_data_.data());
} else { } else {
...@@ -328,10 +393,13 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs, ...@@ -328,10 +393,13 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
for (int j = 0; j < output_names.size(); j++) { for (int j = 0; j < output_names.size(); j++) {
auto output_tensor = predictor_->GetOutputHandle(output_names[j]); auto output_tensor = predictor_->GetOutputHandle(output_names[j]);
std::vector<int> output_shape = output_tensor->shape(); std::vector<int> output_shape = output_tensor->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, int out_num = std::accumulate(
std::multiplies<int>()); output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
output_shape_list.push_back(output_shape); output_shape_list.push_back(output_shape);
if (output_tensor->type() == paddle_infer::DataType::INT32) { if (config_.mask_ && (j == 2)) {
out_mask_data_.resize(out_num);
output_tensor->CopyToCpu(out_mask_data_.data());
} else if (output_tensor->type() == paddle_infer::DataType::INT32) {
out_bbox_num_data_.resize(out_num); out_bbox_num_data_.resize(out_num);
output_tensor->CopyToCpu(out_bbox_num_data_.data()); output_tensor->CopyToCpu(out_bbox_num_data_.data());
} else { } else {
...@@ -356,18 +424,30 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs, ...@@ -356,18 +424,30 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
reg_max = output_shape_list[i][2] / 4 - 1; reg_max = output_shape_list[i][2] / 4 - 1;
} }
float *buffer = new float[out_tensor_list[i].size()]; float *buffer = new float[out_tensor_list[i].size()];
memcpy(buffer, &out_tensor_list[i][0], memcpy(buffer,
&out_tensor_list[i][0],
out_tensor_list[i].size() * sizeof(float)); out_tensor_list[i].size() * sizeof(float));
output_data_list_.push_back(buffer); output_data_list_.push_back(buffer);
} }
PaddleDetection::PicoDetPostProcess( PaddleDetection::PicoDetPostProcess(
result, output_data_list_, config_.fpn_stride_, inputs_.im_shape_, result,
inputs_.scale_factor_, config_.nms_info_["score_threshold"].as<float>(), output_data_list_,
config_.nms_info_["nms_threshold"].as<float>(), num_class, reg_max); config_.fpn_stride_,
inputs_.im_shape_,
inputs_.scale_factor_,
config_.nms_info_["score_threshold"].as<float>(),
config_.nms_info_["nms_threshold"].as<float>(),
num_class,
reg_max);
bbox_num->push_back(result->size()); bbox_num->push_back(result->size());
} else { } else {
is_rbox = output_shape_list[0][output_shape_list[0].size() - 1] % 10 == 0; is_rbox = output_shape_list[0][output_shape_list[0].size() - 1] % 10 == 0;
Postprocess(imgs, result, out_bbox_num_data_, out_tensor_list[0], is_rbox); Postprocess(imgs,
result,
out_bbox_num_data_,
out_tensor_list[0],
out_mask_data_,
is_rbox);
for (int k = 0; k < out_bbox_num_data_.size(); k++) { for (int k = 0; k < out_bbox_num_data_.size(); k++) {
int tmp = out_bbox_num_data_[k]; int tmp = out_bbox_num_data_[k];
bbox_num->push_back(tmp); bbox_num->push_back(tmp);
......
...@@ -60,12 +60,11 @@ void Permute::Run(cv::Mat* im, ImageBlob* data) { ...@@ -60,12 +60,11 @@ void Permute::Run(cv::Mat* im, ImageBlob* data) {
void Resize::Run(cv::Mat* im, ImageBlob* data) { void Resize::Run(cv::Mat* im, ImageBlob* data) {
auto resize_scale = GenerateScale(*im); auto resize_scale = GenerateScale(*im);
data->im_shape_ = {static_cast<float>(im->cols * resize_scale.first),
static_cast<float>(im->rows * resize_scale.second)};
data->in_net_shape_ = {static_cast<float>(im->cols * resize_scale.first),
static_cast<float>(im->rows * resize_scale.second)};
cv::resize( cv::resize(
*im, *im, cv::Size(), resize_scale.first, resize_scale.second, interp_); *im, *im, cv::Size(), resize_scale.first, resize_scale.second, interp_);
data->in_net_shape_ = {static_cast<float>(im->rows),
static_cast<float>(im->cols)};
data->im_shape_ = { data->im_shape_ = {
static_cast<float>(im->rows), static_cast<float>(im->cols), static_cast<float>(im->rows), static_cast<float>(im->cols),
}; };
...@@ -154,6 +153,7 @@ float LetterBoxResize::GenerateScale(const cv::Mat& im) { ...@@ -154,6 +153,7 @@ float LetterBoxResize::GenerateScale(const cv::Mat& im) {
void PadStride::Run(cv::Mat* im, ImageBlob* data) { void PadStride::Run(cv::Mat* im, ImageBlob* data) {
if (stride_ <= 0) { if (stride_ <= 0) {
data->in_net_im_ = im->clone();
return; return;
} }
int rc = im->channels(); int rc = im->channels();
...@@ -242,7 +242,9 @@ bool CheckDynamicInput(const std::vector<cv::Mat>& imgs) { ...@@ -242,7 +242,9 @@ bool CheckDynamicInput(const std::vector<cv::Mat>& imgs) {
int h = imgs.at(0).rows; int h = imgs.at(0).rows;
int w = imgs.at(0).cols; int w = imgs.at(0).cols;
for (int i = 1; i < imgs.size(); ++i) { for (int i = 1; i < imgs.size(); ++i) {
if (imgs.at(i).rows != h || imgs.at(i).cols != w) { int hi = imgs.at(i).rows;
int wi = imgs.at(i).cols;
if (hi != h || wi != w) {
return true; return true;
} }
} }
......
...@@ -206,6 +206,7 @@ class Detector(object): ...@@ -206,6 +206,7 @@ class Detector(object):
for k, v in res.items(): for k, v in res.items():
results[k].append(v) results[k].append(v)
for k, v in results.items(): for k, v in results.items():
if k != 'masks':
results[k] = np.concatenate(v) results[k] = np.concatenate(v)
return results return results
...@@ -296,7 +297,7 @@ class Detector(object): ...@@ -296,7 +297,7 @@ class Detector(object):
if not os.path.exists(self.output_dir): if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir) os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name) out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v') fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1 index = 1
while (1): while (1):
......
...@@ -96,6 +96,8 @@ def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5): ...@@ -96,6 +96,8 @@ def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1) expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :] np_boxes = np_boxes[expect_boxes, :]
np_masks = np_masks[expect_boxes, :, :] np_masks = np_masks[expect_boxes, :, :]
im_h, im_w = im.shape[:2]
np_masks = np_masks[:, :im_h, :im_w]
for i in range(len(np_masks)): for i in range(len(np_masks)):
clsid, score = int(np_boxes[i][0]), np_boxes[i][1] clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
mask = np_masks[i] mask = np_masks[i]
......
...@@ -65,6 +65,14 @@ def get_det_poly_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0): ...@@ -65,6 +65,14 @@ def get_det_poly_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0):
return det_res return det_res
def strip_mask(mask):
    """Crop trailing -1 entries from a stack of masks.

    The crop size is measured on the first mask only: the number of
    entries that are not -1 in its first column gives the kept height,
    and the number in its first row gives the kept width. The same crop
    is then applied to every mask in the stack.

    Args:
        mask: array indexed as [mask index, row, column].
            NOTE(review): presumably all masks in the stack belong to
            one image and share the same padding - confirm at caller.

    Returns:
        ``mask[:, :im_h, :im_w]``, or ``mask`` unchanged when the stack
        holds no masks.
    """
    # An image with zero detections yields an empty stack; indexing
    # mask[0, ...] below would raise IndexError, so return it as is.
    if mask.shape[0] == 0:
        return mask
    row = mask[0, 0, :]
    col = mask[0, :, 0]
    # Entries equal to -1 are counted as padding and cut away.
    im_h = len(col) - np.count_nonzero(col == -1)
    im_w = len(row) - np.count_nonzero(row == -1)
    return mask[:, :im_h, :im_w]
def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map): def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map):
import pycocotools.mask as mask_util import pycocotools.mask as mask_util
seg_res = [] seg_res = []
...@@ -72,8 +80,10 @@ def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map): ...@@ -72,8 +80,10 @@ def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map):
for i in range(len(mask_nums)): for i in range(len(mask_nums)):
cur_image_id = int(image_id[i][0]) cur_image_id = int(image_id[i][0])
det_nums = mask_nums[i] det_nums = mask_nums[i]
mask_i = masks[k:k + det_nums]
mask_i = strip_mask(mask_i)
for j in range(det_nums): for j in range(det_nums):
mask = masks[k].astype(np.uint8) mask = mask_i[j].astype(np.uint8)
score = float(bboxes[k][1]) score = float(bboxes[k][1])
label = int(bboxes[k][0]) label = int(bboxes[k][0])
k = k + 1 k = k + 1
......
...@@ -111,8 +111,8 @@ class CascadeRCNN(BaseArch): ...@@ -111,8 +111,8 @@ class CascadeRCNN(BaseArch):
bbox, bbox_num = self.bbox_post_process( bbox, bbox_num = self.bbox_post_process(
preds, (refined_rois, rois_num), im_shape, scale_factor) preds, (refined_rois, rois_num), im_shape, scale_factor)
# rescale the prediction back to origin image # rescale the prediction back to origin image
bbox_pred = self.bbox_post_process.get_pred(bbox, bbox_num, bbox, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
im_shape, scale_factor) bbox, bbox_num, im_shape, scale_factor)
if not self.with_mask: if not self.with_mask:
return bbox_pred, bbox_num, None return bbox_pred, bbox_num, None
mask_out = self.mask_head(body_feats, bbox, bbox_num, self.inputs) mask_out = self.mask_head(body_feats, bbox, bbox_num, self.inputs)
......
...@@ -112,8 +112,8 @@ class MaskRCNN(BaseArch): ...@@ -112,8 +112,8 @@ class MaskRCNN(BaseArch):
body_feats, bbox, bbox_num, self.inputs, feat_func=feat_func) body_feats, bbox, bbox_num, self.inputs, feat_func=feat_func)
# rescale the prediction back to origin image # rescale the prediction back to origin image
bbox_pred = self.bbox_post_process.get_pred(bbox, bbox_num, bbox, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
im_shape, scale_factor) bbox, bbox_num, im_shape, scale_factor)
origin_shape = self.bbox_post_process.get_origin_shape() origin_shape = self.bbox_post_process.get_origin_shape()
mask_pred = self.mask_post_process(mask_out, bbox_pred, bbox_num, mask_pred = self.mask_post_process(mask_out, bbox_pred, bbox_num,
origin_shape) origin_shape)
......
...@@ -171,7 +171,7 @@ class BBoxPostProcess(nn.Layer): ...@@ -171,7 +171,7 @@ class BBoxPostProcess(nn.Layer):
pred_label = paddle.where(keep_mask, pred_label, pred_label = paddle.where(keep_mask, pred_label,
paddle.ones_like(pred_label) * -1) paddle.ones_like(pred_label) * -1)
pred_result = paddle.concat([pred_label, pred_score, pred_bbox], axis=1) pred_result = paddle.concat([pred_label, pred_score, pred_bbox], axis=1)
return pred_result return bboxes, pred_result, bbox_num
def get_origin_shape(self, ): def get_origin_shape(self, ):
return self.origin_shape_list return self.origin_shape_list
...@@ -179,6 +179,7 @@ class BBoxPostProcess(nn.Layer): ...@@ -179,6 +179,7 @@ class BBoxPostProcess(nn.Layer):
@register @register
class MaskPostProcess(object): class MaskPostProcess(object):
__shared__ = ['export_onnx']
""" """
refer to: refer to:
https://github.com/facebookresearch/detectron2/layers/mask_ops.py https://github.com/facebookresearch/detectron2/layers/mask_ops.py
...@@ -186,9 +187,10 @@ class MaskPostProcess(object): ...@@ -186,9 +187,10 @@ class MaskPostProcess(object):
Get Mask output according to the output from model Get Mask output according to the output from model
""" """
def __init__(self, binary_thresh=0.5): def __init__(self, binary_thresh=0.5, export_onnx=False):
super(MaskPostProcess, self).__init__() super(MaskPostProcess, self).__init__()
self.binary_thresh = binary_thresh self.binary_thresh = binary_thresh
self.export_onnx = export_onnx
def paste_mask(self, masks, boxes, im_h, im_w): def paste_mask(self, masks, boxes, im_h, im_w):
""" """
...@@ -200,6 +202,7 @@ class MaskPostProcess(object): ...@@ -200,6 +202,7 @@ class MaskPostProcess(object):
N = masks.shape[0] N = masks.shape[0]
img_y = paddle.arange(y0_int, y1_int) + 0.5 img_y = paddle.arange(y0_int, y1_int) + 0.5
img_x = paddle.arange(x0_int, x1_int) + 0.5 img_x = paddle.arange(x0_int, x1_int) + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1 img_y = (img_y - y0) / (y1 - y0) * 2 - 1
img_x = (img_x - x0) / (x1 - x0) * 2 - 1 img_x = (img_x - x0) / (x1 - x0) * 2 - 1
# img_x, img_y have shapes (N, w), (N, h) # img_x, img_y have shapes (N, w), (N, h)
...@@ -230,15 +233,34 @@ class MaskPostProcess(object): ...@@ -230,15 +233,34 @@ class MaskPostProcess(object):
""" """
num_mask = mask_out.shape[0] num_mask = mask_out.shape[0]
origin_shape = paddle.cast(origin_shape, 'int32') origin_shape = paddle.cast(origin_shape, 'int32')
# TODO: support bs > 1 and mask output dtype is bool
if self.export_onnx:
h, w = origin_shape[0][0], origin_shape[0][1]
mask_onnx = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:],
h, w)
mask_onnx = mask_onnx >= self.binary_thresh
pred_result = paddle.cast(mask_onnx, 'int32')
else:
max_h = paddle.max(origin_shape[:, 0])
max_w = paddle.max(origin_shape[:, 1])
pred_result = paddle.zeros( pred_result = paddle.zeros(
[num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32') [num_mask, max_h, max_w], dtype='int32') - 1
im_h, im_w = origin_shape[0][0], origin_shape[0][1] id_start = 0
pred_mask = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:], for i in range(paddle.shape(bbox_num)[0]):
im_h, im_w) bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
pred_mask = pred_mask >= self.binary_thresh mask_out_i = mask_out[id_start:id_start + bbox_num[i], :, :]
pred_result = paddle.cast(pred_mask, 'int32') im_h = origin_shape[i, 0]
im_w = origin_shape[i, 1]
bbox_num_i = bbox_num[id_start]
pred_mask = self.paste_mask(mask_out_i[:, None, :, :],
bboxes_i[:, 2:], im_h, im_w)
pred_mask = paddle.cast(pred_mask >= self.binary_thresh,
'int32')
pred_result[id_start:id_start + bbox_num[i], :im_h, :
im_w] = pred_mask
id_start += bbox_num[i]
return pred_result return pred_result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册