From 020e1cd619f61a3ba800510fe3203d7bd2cb41dc Mon Sep 17 00:00:00 2001
From: wangguanzhong
Date: Sat, 2 Apr 2022 23:03:03 +0800
Subject: [PATCH] support mask rcnn bs>2 (#5470)

* support mask rcnn bs>2

* fix resize op in c++ deploy

* fix preprocess in mask rcnn

* fix cascade mask rcnn when box is empty

* update merge batch
---
 deploy/cpp/include/config_parser.h           |   5 +
 deploy/cpp/include/object_detector.h         |   1 +
 deploy/cpp/include/utils.h                   |  14 +-
 deploy/cpp/src/object_detector.cc            | 146 ++++++++++++++-----
 deploy/cpp/src/preprocess_op.cc              |  12 +-
 deploy/python/infer.py                       |   5 +-
 deploy/python/visualize.py                   |   2 +
 ppdet/metrics/json_results.py                |  12 +-
 ppdet/modeling/architectures/cascade_rcnn.py |   4 +-
 ppdet/modeling/architectures/mask_rcnn.py    |   4 +-
 ppdet/modeling/post_process.py               |  44 ++++--
 11 files changed, 187 insertions(+), 62 deletions(-)

diff --git a/deploy/cpp/include/config_parser.h b/deploy/cpp/include/config_parser.h
index 82d103723..1f2e381c5 100644
--- a/deploy/cpp/include/config_parser.h
+++ b/deploy/cpp/include/config_parser.h
@@ -120,6 +120,10 @@ class ConfigPaser {
       }
     }
 
+    if (config["mask"].IsDefined()) {
+      mask_ = config["mask"].as<bool>();
+    }
+
     return true;
   }
   std::string mode_;
@@ -132,6 +136,7 @@ class ConfigPaser {
   std::vector<int> fpn_stride_;
   bool use_dynamic_shape_;
   float conf_thresh_;
+  bool mask_ = false;
 };
 
 }  // namespace PaddleDetection
diff --git a/deploy/cpp/include/object_detector.h b/deploy/cpp/include/object_detector.h
index 0a336c334..30dd09ab7 100644
--- a/deploy/cpp/include/object_detector.h
+++ b/deploy/cpp/include/object_detector.h
@@ -114,6 +114,7 @@ class ObjectDetector {
                    std::vector<PaddleDetection::ObjectResult>* result,
                    std::vector<int> bbox_num,
                    std::vector<float> output_data_,
+                   std::vector<int> output_mask_data_,
                    bool is_rbox);
 
   std::shared_ptr<Predictor> predictor_;
diff --git a/deploy/cpp/include/utils.h b/deploy/cpp/include/utils.h
index 3802e1267..b41db0dac 100644
--- a/deploy/cpp/include/utils.h
+++ b/deploy/cpp/include/utils.h
@@ -14,13 +14,13 @@
 
 #pragma once
 
-#include 
-#include 
-#include 
-#include 
+#include 
 #include 
+#include 
 #include 
-#include 
+#include 
+#include 
+#include 
 
 namespace PaddleDetection {
 
@@ -32,8 +32,10 @@ struct ObjectResult {
   int class_id;
   // Confidence of detected object
   float confidence;
+  // Mask of detected object
+  std::vector<int> mask;
 };
 
 void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold);
 
-}  // namespace PaddleDetection
\ No newline at end of file
+}  // namespace PaddleDetection
diff --git a/deploy/cpp/src/object_detector.cc b/deploy/cpp/src/object_detector.cc
index e455c90aa..8d72408bf 100644
--- a/deploy/cpp/src/object_detector.cc
+++ b/deploy/cpp/src/object_detector.cc
@@ -41,12 +41,17 @@ void ObjectDetector::LoadModel(const std::string &model_dir,
   } else if (run_mode == "trt_int8") {
     precision = paddle_infer::Config::Precision::kInt8;
   } else {
-    printf("run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or "
-           "'trt_int8'");
+    printf(
+        "run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or "
+        "'trt_int8'");
   }
   // set tensorrt
-  config.EnableTensorRtEngine(1 << 30, batch_size, this->min_subgraph_size_,
-                              precision, false, this->trt_calib_mode_);
+  config.EnableTensorRtEngine(1 << 30,
+                              batch_size,
+                              this->min_subgraph_size_,
+                              precision,
+                              false,
+                              this->trt_calib_mode_);
 
   // set use dynamic shape
   if (this->use_dynamic_shape_) {
@@ -64,8 +69,8 @@ void ObjectDetector::LoadModel(const std::string &model_dir,
     const std::map<std::string, std::vector<int>> map_opt_input_shape = {
        {"image", opt_input_shape}};
 
-    config.SetTRTDynamicShapeInfo(map_min_input_shape, map_max_input_shape,
-                                  map_opt_input_shape);
+    config.SetTRTDynamicShapeInfo(
+        map_min_input_shape, map_max_input_shape, map_opt_input_shape);
     std::cout << "TensorRT dynamic shape enabled" << std::endl;
   }
 }
@@ -90,12 +95,15 @@ void ObjectDetector::LoadModel(const std::string &model_dir,
 }
 
 // Visualiztion MaskDetector results
-cv::Mat
-VisualizeResult(const cv::Mat &img,
-                const std::vector<PaddleDetection::ObjectResult> &results,
-                const std::vector<std::string> &lables,
-                const std::vector<int> &colormap, const bool is_rbox = false) {
+cv::Mat VisualizeResult(
+    const cv::Mat &img,
+    const std::vector<PaddleDetection::ObjectResult> &results,
+    const std::vector<std::string> &lables,
+    const std::vector<int> &colormap,
+    const bool is_rbox = false) {
   cv::Mat vis_img = img.clone();
+  int img_h = vis_img.rows;
+  int img_w = vis_img.cols;
   for (int i = 0; i < results.size(); ++i) {
     // Configure color and text size
     std::ostringstream oss;
@@ -129,19 +137,52 @@ VisualizeResult(const cv::Mat &img,
       cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[1], w, h);
       // Draw roi object, text, and background
       cv::rectangle(vis_img, roi, roi_color, 2);
+
+      // Draw mask
+      std::vector<int> mask_v = results[i].mask;
+      if (mask_v.size() > 0) {
+        cv::Mat mask = cv::Mat(img_h, img_w, CV_32S);
+        std::memcpy(mask.data, mask_v.data(), mask_v.size() * sizeof(int));
+
+        cv::Mat colored_img = vis_img.clone();
+
+        std::vector<cv::Mat> contours;
+        cv::Mat hierarchy;
+        mask.convertTo(mask, CV_8U);
+        cv::findContours(
+            mask, contours, hierarchy, cv::RETR_CCOMP, cv::CHAIN_APPROX_SIMPLE);
+        cv::drawContours(colored_img,
+                         contours,
+                         -1,
+                         roi_color,
+                         -1,
+                         cv::LINE_8,
+                         hierarchy,
+                         100);
+
+        cv::Mat debug_roi = vis_img;
+        colored_img = 0.4 * colored_img + 0.6 * vis_img;
+        colored_img.copyTo(vis_img, mask);
+      }
     }
     origin.x = results[i].rect[0];
     origin.y = results[i].rect[1];
     // Configure text background
-    cv::Rect text_back =
-        cv::Rect(results[i].rect[0], results[i].rect[1] - text_size.height,
-                 text_size.width, text_size.height);
+    cv::Rect text_back = cv::Rect(results[i].rect[0],
+                                  results[i].rect[1] - text_size.height,
+                                  text_size.width,
+                                  text_size.height);
     // Draw text, and background
     cv::rectangle(vis_img, text_back, roi_color, -1);
-    cv::putText(vis_img, text, origin, font_face, font_scale,
-                cv::Scalar(255, 255, 255), thickness);
+    cv::putText(vis_img,
+                text,
+                origin,
+                font_face,
+                font_scale,
+                cv::Scalar(255, 255, 255),
+                thickness);
   }
   return vis_img;
 }
@@ -156,10 +197,18 @@ void ObjectDetector::Preprocess(const cv::Mat &ori_im) {
 void ObjectDetector::Postprocess(
     const std::vector<cv::Mat> mats,
     std::vector<PaddleDetection::ObjectResult> *result,
-    std::vector<int> bbox_num, std::vector<float> output_data_,
+    std::vector<int> bbox_num,
+    std::vector<float> output_data_,
+    std::vector<int> output_mask_data_,
     bool is_rbox = false) {
   result->clear();
   int start_idx = 0;
+  int total_num = std::accumulate(bbox_num.begin(), bbox_num.end(), 0);
+  int out_mask_dim = -1;
+  if (config_.mask_) {
+    out_mask_dim = output_mask_data_.size() / total_num;
+  }
+
   for (int im_id = 0; im_id < mats.size(); im_id++) {
     cv::Mat raw_mat = mats[im_id];
     int rh = 1;
@@ -204,6 +253,17 @@ void ObjectDetector::Postprocess(
       result_item.rect = {xmin, ymin, xmax, ymax};
       result_item.class_id = class_id;
       result_item.confidence = score;
+
+      if (config_.mask_) {
+        std::vector<int> mask;
+        for (int k = 0; k < out_mask_dim; ++k) {
+          if (output_mask_data_[k + j * out_mask_dim] > -1) {
+            mask.push_back(output_mask_data_[k + j * out_mask_dim]);
+          }
+        }
+        result_item.mask = mask;
+      }
+
       result->push_back(result_item);
     }
   }
@@ -212,7 +272,8 @@ void ObjectDetector::Postprocess(
 }
 
 void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
-                             const double threshold, const int warmup,
+                             const double threshold,
+                             const int warmup,
                              const int repeats,
                              std::vector<PaddleDetection::ObjectResult> *result,
                              std::vector<int> *bbox_num,
@@ -226,6 +287,7 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
   std::vector<float> scale_factor_all(batch_size * 2);
   std::vector<const float *> output_data_list_;
   std::vector<int> out_bbox_num_data_;
+  std::vector<int> out_mask_data_;
 
   // in_net img for each batch
   std::vector<cv::Mat> in_net_img_all(batch_size);
@@ -240,8 +302,8 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
     scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0];
     scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1];
 
-    in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(),
-                       inputs_.im_data_.end());
+    in_data_all.insert(
+        in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
 
     // collect in_net img
     in_net_img_all[bs_idx] = inputs_.in_net_im_;
@@ -262,8 +324,8 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
     pad_data.resize(rc * rh * rw);
     float *base = pad_data.data();
     for (int i = 0; i < rc; ++i) {
-      cv::extractChannel(pad_img,
-                         cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
+      cv::extractChannel(
+          pad_img, cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
     }
     in_data_all.insert(in_data_all.end(), pad_data.begin(), pad_data.end());
   }
@@ -304,9 +366,12 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
     for (int j = 0; j < output_names.size(); j++) {
       auto output_tensor = predictor_->GetOutputHandle(output_names[j]);
       std::vector<int> output_shape = output_tensor->shape();
-      int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
-                                    std::multiplies<int>());
-      if (output_tensor->type() == paddle_infer::DataType::INT32) {
+      int out_num = std::accumulate(
+          output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
+      if (config_.mask_ && (j == 2)) {
+        out_mask_data_.resize(out_num);
+        output_tensor->CopyToCpu(out_mask_data_.data());
+      } else if (output_tensor->type() == paddle_infer::DataType::INT32) {
         out_bbox_num_data_.resize(out_num);
         output_tensor->CopyToCpu(out_bbox_num_data_.data());
       } else {
@@ -328,10 +393,13 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
     for (int j = 0; j < output_names.size(); j++) {
       auto output_tensor = predictor_->GetOutputHandle(output_names[j]);
       std::vector<int> output_shape = output_tensor->shape();
-      int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
-                                    std::multiplies<int>());
+      int out_num = std::accumulate(
+          output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
       output_shape_list.push_back(output_shape);
-      if (output_tensor->type() == paddle_infer::DataType::INT32) {
+      if (config_.mask_ && (j == 2)) {
+        out_mask_data_.resize(out_num);
+        output_tensor->CopyToCpu(out_mask_data_.data());
+      } else if (output_tensor->type() == paddle_infer::DataType::INT32) {
         out_bbox_num_data_.resize(out_num);
         output_tensor->CopyToCpu(out_bbox_num_data_.data());
       } else {
@@ -356,18 +424,30 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
         reg_max = output_shape_list[i][2] / 4 - 1;
       }
       float *buffer = new float[out_tensor_list[i].size()];
-      memcpy(buffer, &out_tensor_list[i][0],
+      memcpy(buffer,
+             &out_tensor_list[i][0],
              out_tensor_list[i].size() * sizeof(float));
       output_data_list_.push_back(buffer);
     }
     PaddleDetection::PicoDetPostProcess(
-        result, output_data_list_, config_.fpn_stride_, inputs_.im_shape_,
-        inputs_.scale_factor_, config_.nms_info_["score_threshold"].as<float>(),
-        config_.nms_info_["nms_threshold"].as<float>(), num_class, reg_max);
+        result,
+        output_data_list_,
+        config_.fpn_stride_,
+        inputs_.im_shape_,
+        inputs_.scale_factor_,
+        config_.nms_info_["score_threshold"].as<float>(),
+        config_.nms_info_["nms_threshold"].as<float>(),
+        num_class,
+        reg_max);
     bbox_num->push_back(result->size());
   } else {
     is_rbox = output_shape_list[0][output_shape_list[0].size() - 1] % 10 == 0;
-    Postprocess(imgs, result, out_bbox_num_data_, out_tensor_list[0], is_rbox);
+    Postprocess(imgs,
+                result,
+                out_bbox_num_data_,
+                out_tensor_list[0],
+                out_mask_data_,
+                is_rbox);
     for (int k = 0; k < out_bbox_num_data_.size(); k++) {
       int tmp = out_bbox_num_data_[k];
       bbox_num->push_back(tmp);
diff --git a/deploy/cpp/src/preprocess_op.cc b/deploy/cpp/src/preprocess_op.cc
index 4ac3daa30..d4a1fb419 100644
--- a/deploy/cpp/src/preprocess_op.cc
+++ b/deploy/cpp/src/preprocess_op.cc
@@ -60,12 +60,11 @@ void Permute::Run(cv::Mat* im, ImageBlob* data) {
 
 void Resize::Run(cv::Mat* im, ImageBlob* data) {
   auto resize_scale = GenerateScale(*im);
-  data->im_shape_ = {static_cast<float>(im->cols * resize_scale.first),
-                     static_cast<float>(im->rows * resize_scale.second)};
-  data->in_net_shape_ = {static_cast<float>(im->cols * resize_scale.first),
-                         static_cast<float>(im->rows * resize_scale.second)};
   cv::resize(
       *im, *im, cv::Size(), resize_scale.first, resize_scale.second, interp_);
+
+  data->in_net_shape_ = {static_cast<float>(im->rows),
+                         static_cast<float>(im->cols)};
   data->im_shape_ = {
       static_cast<float>(im->rows), static_cast<float>(im->cols),
   };
@@ -154,6 +153,7 @@ float LetterBoxResize::GenerateScale(const cv::Mat& im) {
 
 void PadStride::Run(cv::Mat* im, ImageBlob* data) {
   if (stride_ <= 0) {
+    data->in_net_im_ = im->clone();
     return;
   }
   int rc = im->channels();
@@ -242,7 +242,9 @@ bool CheckDynamicInput(const std::vector<cv::Mat>& imgs) {
   int h = imgs.at(0).rows;
   int w = imgs.at(0).cols;
   for (int i = 1; i < imgs.size(); ++i) {
-    if (imgs.at(i).rows != h || imgs.at(i).cols != w) {
+    int hi = imgs.at(i).rows;
+    int wi = imgs.at(i).cols;
+    if (hi != h || wi != w) {
       return true;
     }
   }
diff --git a/deploy/python/infer.py b/deploy/python/infer.py
index 3296e16e5..2808d2c58 100644
--- a/deploy/python/infer.py
+++ b/deploy/python/infer.py
@@ -206,7 +206,8 @@ class Detector(object):
             for k, v in res.items():
                 results[k].append(v)
         for k, v in results.items():
-            results[k] = np.concatenate(v)
+            if k != 'masks':
+                results[k] = np.concatenate(v)
         return results
 
     def get_timer(self):
@@ -296,7 +297,7 @@ class Detector(object):
         if not os.path.exists(self.output_dir):
             os.makedirs(self.output_dir)
         out_path = os.path.join(self.output_dir, video_out_name)
-        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
         writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
         index = 1
         while (1):
diff --git a/deploy/python/visualize.py b/deploy/python/visualize.py
index 9c07b8491..c26a6e467 100644
--- a/deploy/python/visualize.py
+++ b/deploy/python/visualize.py
@@ -96,6 +96,8 @@ def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
     expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
     np_boxes = np_boxes[expect_boxes, :]
     np_masks = np_masks[expect_boxes, :, :]
+    im_h, im_w = im.shape[:2]
+    np_masks = np_masks[:, :im_h, :im_w]
     for i in range(len(np_masks)):
         clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
         mask = np_masks[i]
diff --git a/ppdet/metrics/json_results.py b/ppdet/metrics/json_results.py
index c703de63b..93354ec1f 100755
--- a/ppdet/metrics/json_results.py
+++ b/ppdet/metrics/json_results.py
@@ -65,6 +65,14 @@ def get_det_poly_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0):
     return det_res
 
 
+def strip_mask(mask):
+    row = mask[0, 0, :]
+    col = mask[0, :, 0]
+    im_h = len(col) - np.count_nonzero(col == -1)
+    im_w = len(row) - np.count_nonzero(row == -1)
+    return mask[:, :im_h, :im_w]
+
+
 def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map):
     import pycocotools.mask as mask_util
     seg_res = []
@@ -72,8 +80,10 @@ def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map):
     for i in range(len(mask_nums)):
         cur_image_id = int(image_id[i][0])
         det_nums = mask_nums[i]
+        mask_i = masks[k:k + det_nums]
+        mask_i = strip_mask(mask_i)
         for j in range(det_nums):
-            mask = masks[k].astype(np.uint8)
+            mask = mask_i[j].astype(np.uint8)
             score = float(bboxes[k][1])
             label = int(bboxes[k][0])
             k = k + 1
diff --git a/ppdet/modeling/architectures/cascade_rcnn.py b/ppdet/modeling/architectures/cascade_rcnn.py
index 4b5caa7a3..fc5949af0 100644
--- a/ppdet/modeling/architectures/cascade_rcnn.py
+++ b/ppdet/modeling/architectures/cascade_rcnn.py
@@ -111,8 +111,8 @@ class CascadeRCNN(BaseArch):
             bbox, bbox_num = self.bbox_post_process(
                 preds, (refined_rois, rois_num), im_shape, scale_factor)
             # rescale the prediction back to origin image
-            bbox_pred = self.bbox_post_process.get_pred(bbox, bbox_num,
-                                                        im_shape, scale_factor)
+            bbox, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
+                bbox, bbox_num, im_shape, scale_factor)
             if not self.with_mask:
                 return bbox_pred, bbox_num, None
             mask_out = self.mask_head(body_feats, bbox, bbox_num, self.inputs)
diff --git a/ppdet/modeling/architectures/mask_rcnn.py b/ppdet/modeling/architectures/mask_rcnn.py
index 43b8bff94..a322f9f8e 100644
--- a/ppdet/modeling/architectures/mask_rcnn.py
+++ b/ppdet/modeling/architectures/mask_rcnn.py
@@ -112,8 +112,8 @@ class MaskRCNN(BaseArch):
                 body_feats, bbox, bbox_num, self.inputs, feat_func=feat_func)
 
             # rescale the prediction back to origin image
-            bbox_pred = self.bbox_post_process.get_pred(bbox, bbox_num,
-                                                        im_shape, scale_factor)
+            bbox, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
+                bbox, bbox_num, im_shape, scale_factor)
             origin_shape = self.bbox_post_process.get_origin_shape()
             mask_pred = self.mask_post_process(mask_out, bbox_pred, bbox_num,
                                                origin_shape)
diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py
index 72e409e40..e74095505 100644
--- a/ppdet/modeling/post_process.py
+++ b/ppdet/modeling/post_process.py
@@ -171,7 +171,7 @@ class BBoxPostProcess(nn.Layer):
         pred_label = paddle.where(keep_mask, pred_label,
                                   paddle.ones_like(pred_label) * -1)
         pred_result = paddle.concat([pred_label, pred_score, pred_bbox], axis=1)
-        return pred_result
+        return bboxes, pred_result, bbox_num
 
     def get_origin_shape(self, ):
         return self.origin_shape_list
@@ -179,6 +179,7 @@ class BBoxPostProcess(nn.Layer):
 
 @register
 class MaskPostProcess(object):
+    __shared__ = ['export_onnx']
     """
     refer to:
     https://github.com/facebookresearch/detectron2/layers/mask_ops.py
@@ -186,9 +187,10 @@ class MaskPostProcess(object):
     Get Mask output according to the output from model
     """
 
-    def __init__(self, binary_thresh=0.5):
+    def __init__(self, binary_thresh=0.5, export_onnx=False):
         super(MaskPostProcess, self).__init__()
         self.binary_thresh = binary_thresh
+        self.export_onnx = export_onnx
 
     def paste_mask(self, masks, boxes, im_h, im_w):
         """
@@ -200,6 +202,7 @@ class MaskPostProcess(object):
         N = masks.shape[0]
         img_y = paddle.arange(y0_int, y1_int) + 0.5
         img_x = paddle.arange(x0_int, x1_int) + 0.5
+
         img_y = (img_y - y0) / (y1 - y0) * 2 - 1
         img_x = (img_x - x0) / (x1 - x0) * 2 - 1
         # img_x, img_y have shapes (N, w), (N, h)
@@ -230,15 +233,34 @@ class MaskPostProcess(object):
         """
         num_mask = mask_out.shape[0]
         origin_shape = paddle.cast(origin_shape, 'int32')
-        # TODO: support bs > 1 and mask output dtype is bool
-        pred_result = paddle.zeros(
-            [num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32')
-
-        im_h, im_w = origin_shape[0][0], origin_shape[0][1]
-        pred_mask = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:],
-                                    im_h, im_w)
-        pred_mask = pred_mask >= self.binary_thresh
-        pred_result = paddle.cast(pred_mask, 'int32')
+
+        if self.export_onnx:
+            h, w = origin_shape[0][0], origin_shape[0][1]
+            mask_onnx = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:],
+                                        h, w)
+            mask_onnx = mask_onnx >= self.binary_thresh
+            pred_result = paddle.cast(mask_onnx, 'int32')
+
+        else:
+            max_h = paddle.max(origin_shape[:, 0])
+            max_w = paddle.max(origin_shape[:, 1])
+            pred_result = paddle.zeros(
+                [num_mask, max_h, max_w], dtype='int32') - 1
+
+            id_start = 0
+            for i in range(paddle.shape(bbox_num)[0]):
+                bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
+                mask_out_i = mask_out[id_start:id_start + bbox_num[i], :, :]
+                im_h = origin_shape[i, 0]
+                im_w = origin_shape[i, 1]
+                bbox_num_i = bbox_num[id_start]
+                pred_mask = self.paste_mask(mask_out_i[:, None, :, :],
+                                            bboxes_i[:, 2:], im_h, im_w)
+                pred_mask = paddle.cast(pred_mask >= self.binary_thresh,
+                                        'int32')
+                pred_result[id_start:id_start + bbox_num[i], :im_h, :
+                            im_w] = pred_mask
+                id_start += bbox_num[i]
 
         return pred_result
-- 
GitLab
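
Note on the batched mask layout this patch introduces: MaskPostProcess.get_pred now pads every image's masks to the per-batch maximum height and width and pre-fills the buffer with -1, so masks from differently sized images can be merged into one tensor, and strip_mask in ppdet/metrics/json_results.py later recovers each image's true size by counting the -1 entries along the first mask's first row and column. The following is a minimal NumPy sketch of that round trip; the shapes, detection counts, and names (per_image_shapes, bbox_num, chunks) are illustrative only, not values taken from the patch.

import numpy as np


def strip_mask(mask):
    # Same logic as the helper added in ppdet/metrics/json_results.py:
    # -1 entries along the first mask's first row/column mark the padding.
    row = mask[0, 0, :]
    col = mask[0, :, 0]
    im_h = len(col) - np.count_nonzero(col == -1)
    im_w = len(row) - np.count_nonzero(row == -1)
    return mask[:, :im_h, :im_w]


# Hypothetical batch: image 0 is 4x6, image 1 is 3x5, so the merged buffer is
# padded to the per-batch maximum (4, 6) and filled with -1, mirroring
# paddle.zeros([num_mask, max_h, max_w], dtype='int32') - 1 in get_pred.
per_image_shapes = [(4, 6), (3, 5)]
bbox_num = [2, 1]  # detections per image
max_h = max(h for h, _ in per_image_shapes)
max_w = max(w for _, w in per_image_shapes)

chunks = []
for (h, w), n in zip(per_image_shapes, bbox_num):
    m = np.full((n, max_h, max_w), -1, dtype=np.int32)
    m[:, :h, :w] = np.random.randint(0, 2, size=(n, h, w))  # binary masks
    chunks.append(m)
pred_result = np.concatenate(chunks, axis=0)  # shape (3, 4, 6)

# Consumers slice per image by bbox_num and strip the padding back off,
# as get_seg_res does before encoding masks for COCO evaluation.
start = 0
for n in bbox_num:
    masks_i = strip_mask(pred_result[start:start + n])
    print(masks_i.shape)  # (2, 4, 6) for image 0, then (1, 3, 5) for image 1
    start += n

Padding with -1 rather than 0 keeps the sentinel distinguishable from genuine background pixels, so strip_mask can infer the real extent from the data itself; deploy/python/visualize.py takes the simpler route and crops the merged masks to the original image's height and width.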