Unverified commit 0947171b authored by cnn, committed by GitHub

[dev] s2anet fix export and deploy (#2919)

* s2anet fix export and deploy

* remove redundant code

* fix cpp for s2anet deploy

* fix bug of get_categories

* rename poly_to_rbox to poly2rbox

* add some comment for function
Parent 912833f2
@@ -9,7 +9,7 @@ TrainReader:
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
   batch_transforms:
-  - RboxPadBatch: {pad_to_stride: 32, pad_gt: true}
+  - PadBatch: {pad_to_stride: 32}
   batch_size: 1
   shuffle: true
   drop_last: true
@@ -22,7 +22,7 @@ EvalReader:
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
   batch_transforms:
-  - RboxPadBatch: {pad_to_stride: 32, pad_gt: false}
+  - PadBatch: {pad_to_stride: 32}
   batch_size: 1
   shuffle: false
   drop_last: false
@@ -36,7 +36,7 @@ TestReader:
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
   batch_transforms:
-  - RboxPadBatch: {pad_to_stride: 32, pad_gt: false}
+  - PadBatch: {pad_to_stride: 32}
   batch_size: 1
   shuffle: false
   drop_last: false
@@ -27,7 +27,7 @@ parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3)))
 if parent_path not in sys.path:
     sys.path.append(parent_path)
-from ppdet.modeling.bbox_utils import poly_to_rbox
+from ppdet.modeling.bbox_utils import poly2rbox
 from ppdet.utils.logger import setup_logger
 logger = setup_logger(__name__)
@@ -118,7 +118,7 @@ def dota_2_coco(image_dir,
         # rbox or bbox
         if is_obb:
             polys = [single_obj_poly]
-            rboxs = poly_to_rbox(polys)
+            rboxs = poly2rbox(polys)
             rbox = rboxs[0].tolist()
             single_obj['bbox'] = rbox
             single_obj['area'] = rbox[2] * rbox[3]
......
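For orientation: the renamed `poly2rbox` turns each 8-coordinate DOTA polygon into the 5-parameter rotated box that `single_obj['bbox']` stores. Below is a minimal numpy sketch of that conversion; `poly2rbox_sketch` is a hypothetical name, and the real ppdet function additionally normalizes the angle range and picks a canonical starting corner.

```python
import numpy as np

def poly2rbox_sketch(poly):
    """Convert one quadrilateral [x0,y0,...,x3,y3] to (x_ctr, y_ctr, w, h, angle).

    Illustrative only; the ppdet implementation also normalizes the angle
    and re-orders the corners before computing it.
    """
    pts = np.asarray(poly, dtype=np.float64).reshape(4, 2)
    x_ctr, y_ctr = pts.mean(axis=0)              # center = mean of the 4 corners
    edge1 = np.linalg.norm(pts[1] - pts[0])      # first edge length
    edge2 = np.linalg.norm(pts[2] - pts[1])      # adjacent edge length
    w, h = max(edge1, edge2), min(edge1, edge2)  # keep w >= h by convention
    # angle of the longer edge w.r.t. the x-axis
    if edge1 >= edge2:
        angle = np.arctan2(pts[1, 1] - pts[0, 1], pts[1, 0] - pts[0, 0])
    else:
        angle = np.arctan2(pts[2, 1] - pts[1, 1], pts[2, 0] - pts[1, 0])
    return np.array([x_ctr, y_ctr, w, h, angle], dtype=np.float32)

# axis-aligned 40x20 box -> approximately [30, 20, 40, 20, 0]
print(poly2rbox_sketch([10, 10, 50, 10, 50, 30, 10, 30]))
```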
@@ -51,7 +51,8 @@ std::vector<int> GenerateColorMap(int num_class);
 cv::Mat VisualizeResult(const cv::Mat& img,
                         const std::vector<ObjectResult>& results,
                         const std::vector<std::string>& lable_list,
-                        const std::vector<int>& colormap);
+                        const std::vector<int>& colormap,
+                        const bool is_rbox);
 class ObjectDetector {
@@ -120,7 +121,8 @@ class ObjectDetector {
   // Postprocess result
   void Postprocess(
       const cv::Mat& raw_mat,
-      std::vector<ObjectResult>* result);
+      std::vector<ObjectResult>* result,
+      bool is_rbox);
   std::shared_ptr<Predictor> predictor_;
   Preprocessor preprocessor_;
......
@@ -195,16 +195,30 @@ void PredictVideo(const std::string& video_path,
   // Capture all frames and do inference
   cv::Mat frame;
   int frame_id = 0;
+  bool is_rbox = false;
   while (capture.read(frame)) {
     if (frame.empty()) {
       break;
     }
     det->Predict(frame, 0.5, 0, 1, &result, &det_times);
-    cv::Mat out_im = PaddleDetection::VisualizeResult(
-        frame, result, labels, colormap);
     for (const auto& item : result) {
-      printf("In frame id %d, we detect: class=%d confidence=%.2f rect=[%d %d %d %d]\n",
-             frame_id,
+      if (item.rect.size() > 6){
+        is_rbox = true;
+        printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
+               item.class_id,
+               item.confidence,
+               item.rect[0],
+               item.rect[1],
+               item.rect[2],
+               item.rect[3],
+               item.rect[4],
+               item.rect[5],
+               item.rect[6],
+               item.rect[7]);
+      }
+      else{
+        printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
             item.class_id,
             item.confidence,
             item.rect[0],
@@ -212,6 +226,11 @@ void PredictVideo(const std::string& video_path,
             item.rect[2],
             item.rect[3]);
+      }
     }
+    cv::Mat out_im = PaddleDetection::VisualizeResult(
+        frame, result, labels, colormap, is_rbox);
     video_out.write(out_im);
     frame_id += 1;
   }
@@ -231,11 +250,27 @@ void PredictImage(const std::vector<std::string> all_img_list,
     // Store all detected result
     std::vector<PaddleDetection::ObjectResult> result;
     std::vector<double> det_times;
+    bool is_rbox = false;
     if (run_benchmark) {
       det->Predict(im, threshold, 10, 10, &result, &det_times);
     } else {
       det->Predict(im, 0.5, 0, 1, &result, &det_times);
       for (const auto& item : result) {
+        if (item.rect.size() > 6){
+          is_rbox = true;
+          printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
+                 item.class_id,
+                 item.confidence,
+                 item.rect[0],
+                 item.rect[1],
+                 item.rect[2],
+                 item.rect[3],
+                 item.rect[4],
+                 item.rect[5],
+                 item.rect[6],
+                 item.rect[7]);
+        }
+        else{
          printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
                 item.class_id,
                 item.confidence,
@@ -244,11 +279,12 @@ void PredictImage(const std::vector<std::string> all_img_list,
                 item.rect[2],
                 item.rect[3]);
+        }
       }
       // Visualization result
       auto labels = det->GetLabelList();
       auto colormap = PaddleDetection::GenerateColorMap(labels.size());
       cv::Mat vis_img = PaddleDetection::VisualizeResult(
-          im, result, labels, colormap);
+          im, result, labels, colormap, is_rbox);
       std::vector<int> compression_params;
       compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
       compression_params.push_back(95);
......
@@ -94,13 +94,10 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
 cv::Mat VisualizeResult(const cv::Mat& img,
                         const std::vector<ObjectResult>& results,
                         const std::vector<std::string>& lable_list,
-                        const std::vector<int>& colormap) {
+                        const std::vector<int>& colormap,
+                        const bool is_rbox=false) {
   cv::Mat vis_img = img.clone();
   for (int i = 0; i < results.size(); ++i) {
-    int w = results[i].rect[1] - results[i].rect[0];
-    int h = results[i].rect[3] - results[i].rect[2];
-    cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[2], w, h);
     // Configure color and text size
     std::ostringstream oss;
     oss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
@@ -120,17 +117,37 @@ cv::Mat VisualizeResult(const cv::Mat& img,
                                       thickness,
                                       nullptr);
     cv::Point origin;
-    origin.x = roi.x;
-    origin.y = roi.y;
+    if (is_rbox)
+    {
+      // Draw object, text, and background
+      for (int k=0; k<4; k++)
+      {
+        cv::Point pt1 = cv::Point(results[i].rect[(k*2)%8],
+                                  results[i].rect[(k*2+1)%8]);
+        cv::Point pt2 = cv::Point(results[i].rect[(k*2+2)%8],
+                                  results[i].rect[(k*2+3)%8]);
+        cv::line(vis_img, pt1, pt2, roi_color, 2);
+      }
+    }
+    else
+    {
+      int w = results[i].rect[1] - results[i].rect[0];
+      int h = results[i].rect[3] - results[i].rect[2];
+      cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[2], w, h);
+      // Draw roi object, text, and background
+      cv::rectangle(vis_img, roi, roi_color, 2);
+    }
+    origin.x = results[i].rect[0];
+    origin.y = results[i].rect[1];
     // Configure text background
     cv::Rect text_back = cv::Rect(results[i].rect[0],
                                   results[i].rect[2] - text_size.height,
                                   text_size.width,
                                   text_size.height);
-    // Draw roi object, text, and background
-    cv::rectangle(vis_img, roi, roi_color, 2);
+    // Draw text, and background
     cv::rectangle(vis_img, text_back, roi_color, -1);
     cv::putText(vis_img,
                 text,
@@ -152,7 +169,8 @@ void ObjectDetector::Preprocess(const cv::Mat& ori_im) {
 void ObjectDetector::Postprocess(
     const cv::Mat& raw_mat,
-    std::vector<ObjectResult>* result) {
+    std::vector<ObjectResult>* result,
+    bool is_rbox=false) {
   result->clear();
   int rh = 1;
   int rw = 1;
@@ -161,6 +179,33 @@ void ObjectDetector::Postprocess(
     rw = raw_mat.cols;
   }
+  if (is_rbox)
+  {
+    int total_size = output_data_.size() / 10;
+    for (int j = 0; j < total_size; ++j) {
+      // Class id
+      int class_id = static_cast<int>(round(output_data_[0 + j * 10]));
+      // Confidence score
+      float score = output_data_[1 + j * 10];
+      int x1 = (output_data_[2 + j * 10] * rw);
+      int y1 = (output_data_[3 + j * 10] * rh);
+      int x2 = (output_data_[4 + j * 10] * rw);
+      int y2 = (output_data_[5 + j * 10] * rh);
+      int x3 = (output_data_[6 + j * 10] * rw);
+      int y3 = (output_data_[7 + j * 10] * rh);
+      int x4 = (output_data_[8 + j * 10] * rw);
+      int y4 = (output_data_[9 + j * 10] * rh);
+      if (score > threshold_ && class_id > -1) {
+        ObjectResult result_item;
+        result_item.rect = {x1, y1, x2, y2, x3, y3, x4, y4};
+        result_item.class_id = class_id;
+        result_item.confidence = score;
+        result->push_back(result_item);
+      }
+    }
+  }
+  else
+  {
   int total_size = output_data_.size() / 6;
   for (int j = 0; j < total_size; ++j) {
     // Class id
@@ -181,6 +226,7 @@ void ObjectDetector::Postprocess(
       result->push_back(result_item);
     }
   }
+  }
 }
 void ObjectDetector::Predict(const cv::Mat& im,
@@ -231,6 +277,7 @@ void ObjectDetector::Predict(const cv::Mat& im,
     out_tensor->CopyToCpu(output_data_.data());
   }
+  bool is_rbox = false;
   auto inference_start = std::chrono::steady_clock::now();
   for (int i = 0; i < repeats; i++)
   {
@@ -244,6 +291,7 @@ void ObjectDetector::Predict(const cv::Mat& im,
     for (int j = 0; j < output_shape.size(); ++j) {
       output_size *= output_shape[j];
     }
+    is_rbox = output_shape[output_shape.size()-1] % 10 == 0;
     if (output_size < 6) {
       std::cerr << "[WARNING] No object detected." << std::endl;
@@ -254,7 +302,7 @@ void ObjectDetector::Predict(const cv::Mat& im,
   auto inference_end = std::chrono::steady_clock::now();
   auto postprocess_start = std::chrono::steady_clock::now();
   // Postprocessing result
-  Postprocess(im, result);
+  Postprocess(im, result, is_rbox);
   auto postprocess_end = std::chrono::steady_clock::now();
   std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
@@ -263,6 +311,7 @@ void ObjectDetector::Predict(const cv::Mat& im,
   times->push_back(double(inference_diff.count() / repeats * 1000));
   std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
   times->push_back(double(postprocess_diff.count() * 1000));
 }
 std::vector<int> GenerateColorMap(int num_class) {
......
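The deploy code decides between the two output layouts purely from the last output dimension: S2ANet emits [class_id, score, x1, y1, ..., x4, y4] (10 values per detection), while axis-aligned detectors emit [class_id, score, xmin, ymin, xmax, ymax] (6 values), hence `is_rbox = last_dim % 10 == 0`. A small Python sketch of the same dispatch; the names here are illustrative only and not part of the deploy code.

```python
import numpy as np

def parse_detections(output, threshold=0.5):
    """Split raw predictor output into (class_id, score, coords) tuples.

    `output` is an (N, 6) array for axis-aligned boxes or an (N, 10) array
    for rotated boxes; the column count is how the C++ deploy code decides
    which branch to take.
    """
    output = np.asarray(output)
    is_rbox = output.shape[-1] % 10 == 0
    results = []
    for row in output:
        class_id, score = int(round(row[0])), float(row[1])
        if score > threshold and class_id > -1:
            results.append((class_id, score, row[2:].tolist()))
    return is_rbox, results

# one fake rotated-box detection: class 3, score 0.9, four corner points
fake = [[3, 0.9, 10, 10, 50, 10, 50, 30, 10, 30]]
print(parse_detections(fake))
```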
@@ -37,6 +37,7 @@ SUPPORT_MODELS = {
     'FCOS',
     'SOLOv2',
     'TTFNet',
+    'S2ANet',
 }
......
@@ -130,22 +130,29 @@ def draw_box(im, np_boxes, labels, threshold=0.5):
     for dt in np_boxes:
         clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
-        xmin, ymin, xmax, ymax = bbox
-        print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'
-              'right_bottom:[{:.2f},{:.2f}]'.format(
-                  int(clsid), score, xmin, ymin, xmax, ymax))
-        w = xmax - xmin
-        h = ymax - ymin
         if clsid not in clsid2color:
             clsid2color[clsid] = color_list[clsid]
         color = tuple(clsid2color[clsid])
+        if len(bbox) == 4:
+            xmin, ymin, xmax, ymax = bbox
+            print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'
+                  'right_bottom:[{:.2f},{:.2f}]'.format(
+                      int(clsid), score, xmin, ymin, xmax, ymax))
            # draw bbox
            draw.line(
                [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
                 (xmin, ymin)],
                width=draw_thickness,
                fill=color)
+        elif len(bbox) == 8:
+            x1, y1, x2, y2, x3, y3, x4, y4 = bbox
+            draw.line(
+                [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
+                width=2,
+                fill=color)
+            xmin = min(x1, x2, x3, x4)
+            ymin = min(y1, y2, y3, y4)
         # draw label
         text = "{} {:.4f}".format(labels[clsid], score)
......
@@ -33,7 +33,7 @@ logger = setup_logger(__name__)
 __all__ = [
     'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
-    'Gt2TTFTarget', 'Gt2Solov2Target', 'RboxPadBatch'
+    'Gt2TTFTarget', 'Gt2Solov2Target'
 ]
@@ -87,6 +87,12 @@ class PadBatch(BaseOperator):
                 padding_segm[:, :im_h, :im_w] = gt_segm
                 data['gt_segm'] = padding_segm
+            if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None:
+                # ploy to rbox
+                polys = data['gt_rbox2poly']
+                rbox = bbox_utils.poly2rbox(polys)
+                data['gt_rbox'] = rbox
+
         return samples
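With the dedicated RboxPadBatch operator removed, the standard PadBatch now covers both jobs: it pads every image to the per-batch maximum size rounded up to a multiple of `pad_to_stride`, and it derives `gt_rbox` from `gt_rbox2poly` when present. A minimal numpy sketch of the padding rule only (illustrative, not the ppdet operator itself):

```python
import numpy as np

def pad_batch_to_stride(images, stride=32):
    """Pad a list of CHW float arrays to a shared H/W that `stride` divides.

    Mirrors the idea of `PadBatch: {pad_to_stride: 32}`: take the per-batch
    max height/width, round each up to a multiple of `stride`, then zero-pad.
    """
    max_h = max(im.shape[1] for im in images)
    max_w = max(im.shape[2] for im in images)
    if stride > 0:
        max_h = int(np.ceil(max_h / stride) * stride)
        max_w = int(np.ceil(max_w / stride) * stride)
    padded = []
    for im in images:
        c, h, w = im.shape
        out = np.zeros((c, max_h, max_w), dtype=np.float32)
        out[:, :h, :w] = im  # original pixels stay in the top-left corner
        padded.append(out)
    return np.stack(padded)

batch = [np.ones((3, 500, 700), np.float32), np.ones((3, 480, 640), np.float32)]
print(pad_batch_to_stride(batch).shape)  # (2, 3, 512, 704)
```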
@@ -739,111 +745,3 @@ class Gt2Solov2Target(BaseOperator):
             data['grid_order{}'.format(idx)] = gt_grid_order
         return samples
@register_op
class RboxPadBatch(BaseOperator):
"""
Pad a batch of samples so they can be divisible by a stride.
The layout of each image should be 'CHW'. And convert poly to rbox.
Args:
pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
height and width is divisible by `pad_to_stride`.
"""
def __init__(self, pad_to_stride=0, pad_gt=False):
super(RboxPadBatch, self).__init__()
self.pad_to_stride = pad_to_stride
self.pad_gt = pad_gt
def __call__(self, samples, context=None):
"""
Args:
samples (list): a batch of sample, each is dict.
"""
coarsest_stride = self.pad_to_stride
max_shape = np.array([data['image'].shape for data in samples]).max(
axis=0)
if coarsest_stride > 0:
max_shape[1] = int(
np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
max_shape[2] = int(
np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
for data in samples:
im = data['image']
im_c, im_h, im_w = im.shape[:]
padding_im = np.zeros(
(im_c, max_shape[1], max_shape[2]), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
data['image'] = padding_im
if 'semantic' in data and data['semantic'] is not None:
semantic = data['semantic']
padding_sem = np.zeros(
(1, max_shape[1], max_shape[2]), dtype=np.float32)
padding_sem[:, :im_h, :im_w] = semantic
data['semantic'] = padding_sem
if 'gt_segm' in data and data['gt_segm'] is not None:
gt_segm = data['gt_segm']
padding_segm = np.zeros(
(gt_segm.shape[0], max_shape[1], max_shape[2]),
dtype=np.uint8)
padding_segm[:, :im_h, :im_w] = gt_segm
data['gt_segm'] = padding_segm
if self.pad_gt:
gt_num = []
if 'gt_poly' in data and data['gt_poly'] is not None and len(data[
'gt_poly']) > 0:
pad_mask = True
else:
pad_mask = False
if pad_mask:
poly_num = []
poly_part_num = []
point_num = []
for data in samples:
gt_num.append(data['gt_bbox'].shape[0])
if pad_mask:
poly_num.append(len(data['gt_poly']))
for poly in data['gt_poly']:
poly_part_num.append(int(len(poly)))
for p_p in poly:
point_num.append(int(len(p_p) / 2))
gt_num_max = max(gt_num)
for i, sample in enumerate(samples):
assert 'gt_rbox' in sample
assert 'gt_rbox2poly' in sample
gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32)
gt_class_data = -np.ones([gt_num_max], dtype=np.int32)
is_crowd_data = np.ones([gt_num_max], dtype=np.int32)
if pad_mask:
poly_num_max = max(poly_num)
poly_part_num_max = max(poly_part_num)
point_num_max = max(point_num)
gt_masks_data = -np.ones(
[poly_num_max, poly_part_num_max, point_num_max, 2],
dtype=np.float32)
gt_num = sample['gt_bbox'].shape[0]
gt_box_data[0:gt_num, :] = sample['gt_bbox']
gt_class_data[0:gt_num] = np.squeeze(sample['gt_class'])
is_crowd_data[0:gt_num] = np.squeeze(sample['is_crowd'])
if pad_mask:
for j, poly in enumerate(sample['gt_poly']):
for k, p_p in enumerate(poly):
pp_np = np.array(p_p).reshape(-1, 2)
gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
sample['gt_poly'] = gt_masks_data
sample['gt_bbox'] = gt_box_data
sample['gt_class'] = gt_class_data
sample['is_crowd'] = is_crowd_data
# ploy to rbox
polys = sample['gt_rbox2poly']
rbox = bbox_utils.poly_to_rbox(polys)
sample['gt_rbox'] = rbox
return samples
@@ -2007,7 +2007,7 @@ class Rbox2Poly(BaseOperator):
         x2 = x_ctr + width / 2.0
         y2 = y_ctr + height / 2.0
         sample['gt_bbox'] = np.stack([x1, y1, x2, y2], axis=1)
-        polys = bbox_utils.rbox2poly(rrects)
+        polys = bbox_utils.rbox2poly_np(rrects)
         sample['gt_rbox2poly'] = polys
         return sample
......
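`rbox2poly_np` is the numpy inverse of `poly2rbox`: it rotates the axis-aligned corner offsets of a (w, h) box by the angle and translates them to the center. A self-contained sketch of that corner expansion; this is a simplified stand-in, since the ppdet version additionally re-orders the corners via `get_best_begin_point_single`.

```python
import numpy as np

def rbox2poly_np_sketch(rrects):
    """Expand (x_ctr, y_ctr, w, h, angle) boxes into 8-coordinate polygons."""
    polys = []
    for x_ctr, y_ctr, w, h, angle in np.asarray(rrects, dtype=np.float64):
        # corner offsets of an axis-aligned box centered at the origin (2x4)
        rect = np.array([[-w / 2, w / 2, w / 2, -w / 2],
                         [-h / 2, -h / 2, h / 2, h / 2]])
        # rotate by `angle`, then translate to the box center
        R = np.array([[np.cos(angle), -np.sin(angle)],
                      [np.sin(angle), np.cos(angle)]])
        corners = R @ rect + np.array([[x_ctr], [y_ctr]])
        polys.append(corners.T.reshape(-1))  # -> [x0,y0,x1,y1,x2,y2,x3,y3]
    return np.array(polys, dtype=np.float32)

# a 40x20 box at (30, 20) with no rotation -> its four corners
print(rbox2poly_np_sketch([[30, 20, 40, 20, 0]]))
```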
@@ -376,7 +376,8 @@ class Trainer(object):
         imid2path = self.dataset.get_imid2path()
         anno_file = self.dataset.get_anno()
-        clsid2catid, catid2name = get_categories(self.cfg.metric, anno_file)
+        clsid2catid, catid2name = get_categories(
+            self.cfg.metric, anno_file=anno_file)
         # Run Infer
         self.status['mode'] = 'test'
......
@@ -263,146 +263,7 @@ def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
     return iou
-def rect2rbox(bboxes):
"""
:param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
:return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle)
"""
bboxes = bboxes.reshape(-1, 4)
num_boxes = bboxes.shape[0]
x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
edges1 = np.abs(bboxes[:, 2] - bboxes[:, 0])
edges2 = np.abs(bboxes[:, 3] - bboxes[:, 1])
angles = np.zeros([num_boxes], dtype=bboxes.dtype)
inds = edges1 < edges2
rboxes = np.stack((x_ctr, y_ctr, edges1, edges2, angles), axis=1)
rboxes[inds, 2] = edges2[inds]
rboxes[inds, 3] = edges1[inds]
rboxes[inds, 4] = np.pi / 2.0
return rboxes
def delta2rbox(Rrois,
deltas,
means=[0, 0, 0, 0, 0],
stds=[1, 1, 1, 1, 1],
wh_ratio_clip=1e-6):
"""
:param Rrois: (cx, cy, w, h, theta)
:param deltas: (dx, dy, dw, dh, dtheta)
:param means:
:param stds:
:param wh_ratio_clip:
:return:
"""
deltas = paddle.reshape(deltas, [-1, deltas.shape[-1]])
denorm_deltas = deltas * stds + means
dx = denorm_deltas[:, 0]
dy = denorm_deltas[:, 1]
dw = denorm_deltas[:, 2]
dh = denorm_deltas[:, 3]
dangle = denorm_deltas[:, 4]
max_ratio = np.abs(np.log(wh_ratio_clip))
dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
Rroi_x = Rrois[:, 0]
Rroi_y = Rrois[:, 1]
Rroi_w = Rrois[:, 2]
Rroi_h = Rrois[:, 3]
Rroi_angle = Rrois[:, 4]
gx = dx * Rroi_w * paddle.cos(Rroi_angle) - dy * Rroi_h * paddle.sin(
Rroi_angle) + Rroi_x
gy = dx * Rroi_w * paddle.sin(Rroi_angle) + dy * Rroi_h * paddle.cos(
Rroi_angle) + Rroi_y
gw = Rroi_w * dw.exp()
gh = Rroi_h * dh.exp()
ga = np.pi * dangle + Rroi_angle
ga = (ga + np.pi / 4) % np.pi - np.pi / 4
ga = paddle.to_tensor(ga)
gw = paddle.to_tensor(gw, dtype='float32')
gh = paddle.to_tensor(gh, dtype='float32')
bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
return bboxes
def rbox2delta(proposals, gt, means=[0, 0, 0, 0, 0], stds=[1, 1, 1, 1, 1]):
"""
Args:
proposals:
gt:
means: 1x5
stds: 1x5
Returns:
"""
proposals = proposals.astype(np.float64)
PI = np.pi
gt_widths = gt[..., 2]
gt_heights = gt[..., 3]
gt_angle = gt[..., 4]
proposals_widths = proposals[..., 2]
proposals_heights = proposals[..., 3]
proposals_angle = proposals[..., 4]
coord = gt[..., 0:2] - proposals[..., 0:2]
dx = (np.cos(proposals[..., 4]) * coord[..., 0] + np.sin(proposals[..., 4])
* coord[..., 1]) / proposals_widths
dy = (-np.sin(proposals[..., 4]) * coord[..., 0] + np.cos(proposals[..., 4])
* coord[..., 1]) / proposals_heights
dw = np.log(gt_widths / proposals_widths)
dh = np.log(gt_heights / proposals_heights)
da = (gt_angle - proposals_angle)
da = (da + PI / 4) % PI - PI / 4
da /= PI
deltas = np.stack([dx, dy, dw, dh, da], axis=-1)
means = np.array(means, dtype=deltas.dtype)
stds = np.array(stds, dtype=deltas.dtype)
deltas = (deltas - means) / stds
deltas = deltas.astype(np.float32)
return deltas
def bbox_decode(bbox_preds,
anchors,
means=[0, 0, 0, 0, 0],
stds=[1, 1, 1, 1, 1]):
"""decode bbox from deltas
Args:
bbox_preds: [N,H,W,5]
anchors: [H*W,5]
return:
bboxes: [N,H,W,5]
"""
num_imgs, H, W, _ = bbox_preds.shape
bboxes_list = []
for img_id in range(num_imgs):
bbox_pred = bbox_preds[img_id]
# bbox_pred.shape=[5,H,W]
bbox_delta = bbox_pred
bboxes = delta2rbox(
anchors, bbox_delta, means, stds, wh_ratio_clip=1e-6)
bboxes = paddle.reshape(bboxes, [H, W, 5])
bboxes_list.append(bboxes)
return paddle.stack(bboxes_list, axis=0)
-def poly_to_rbox(polys):
+def poly2rbox(polys):
     """
     poly:[x0,y0,x1,y1,x2,y2,x3,y3]
     to
@@ -479,37 +340,16 @@ def get_best_begin_point_single(coordinate):
     return np.array(combinate[force_flag]).reshape(8)
-def rbox2poly_single(rrect):
-    """
-    rrect:[x_ctr,y_ctr,w,h,angle]
-    to
-    poly:[x0,y0,x1,y1,x2,y2,x3,y3]
-    """
-    x_ctr, y_ctr, width, height, angle = rrect[:5]
-    tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
-    # rect 2x4
-    rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
-    R = np.array([[np.cos(angle), -np.sin(angle)],
-                  [np.sin(angle), np.cos(angle)]])
-    # poly
-    poly = R.dot(rect)
-    x0, x1, x2, x3 = poly[0, :4] + x_ctr
-    y0, y1, y2, y3 = poly[1, :4] + y_ctr
-    poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float32)
-    poly = get_best_begin_point_single(poly)
-    return poly
-def rbox2poly(rrects):
+def rbox2poly_np(rrects):
     """
     rrect:[x_ctr,y_ctr,w,h,angle]
     to
     poly:[x0,y0,x1,y1,x2,y2,x3,y3]
     """
     polys = []
-    rrects = rrects.numpy()
     for i in range(rrects.shape[0]):
         rrect = rrects[i]
-        # x_ctr, y_ctr, width, height, angle = rrect[:5]
         x_ctr = rrect[0]
         y_ctr = rrect[1]
         width = rrect[2]
@@ -529,13 +369,13 @@ def rbox2poly(rrects):
     return polys
-def pd_rbox2poly(rrects):
+def rbox2poly(rrects):
     """
     rrect:[x_ctr,y_ctr,w,h,angle]
     to
     poly:[x0,y0,x1,y1,x2,y2,x3,y3]
     """
-    N = rrects.shape[0]
+    N = paddle.shape(rrects)[0]
     x_ctr = rrects[:, 0]
     y_ctr = rrects[:, 1]
@@ -561,14 +401,10 @@ def pd_rbox2poly(rrects):
     polys = paddle.transpose(polys, [2, 1, 0])
     polys = paddle.reshape(polys, [-1, N])
     polys = paddle.transpose(polys, [1, 0])
-    polys[:, 0] += x_ctr
-    polys[:, 2] += x_ctr
-    polys[:, 4] += x_ctr
-    polys[:, 6] += x_ctr
-    polys[:, 1] += y_ctr
-    polys[:, 3] += y_ctr
-    polys[:, 5] += y_ctr
-    polys[:, 7] += y_ctr
+    tmp = paddle.stack(
+        [x_ctr, y_ctr, x_ctr, y_ctr, x_ctr, y_ctr, x_ctr, y_ctr], axis=1)
+    polys = polys + tmp
     return polys
......
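In the paddle `rbox2poly`, the eight per-column in-place additions (`polys[:, 0] += x_ctr`, ...) are replaced by stacking the centers into an [N, 8] tensor and adding once, presumably because a single stack-and-add traces more cleanly for static-graph export than repeated in-place slice assignment. A small numpy sketch of the equivalence (numpy is used here only to keep the example dependency-free; the diff does the same with `paddle.stack`):

```python
import numpy as np

# offsets of the 4 corners relative to the box center, shape [N, 8]
corner_offsets = np.array([[-20., -10., 20., -10., 20., 10., -20., 10.]])
x_ctr = np.array([30.])
y_ctr = np.array([20.])

# old style: eight in-place slice updates
a = corner_offsets.copy()
for k in range(4):
    a[:, 2 * k] += x_ctr
    a[:, 2 * k + 1] += y_ctr

# new style: one stacked [x, y, x, y, ...] tensor and a single add
tmp = np.stack([x_ctr, y_ctr] * 4, axis=1)
b = corner_offsets + tmp

print(np.allclose(a, b))  # True
```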
@@ -17,26 +17,21 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.nn.initializer import Normal, Constant
 from ppdet.core.workspace import register
-from ppdet.modeling import ops
 from ppdet.modeling import bbox_utils
 from ppdet.modeling.proposal_generator.target_layer import RBoxAssigner
 import numpy as np
-class S2ANetAnchorGenerator(object):
+class S2ANetAnchorGenerator(nn.Layer):
     """
-    S2ANetAnchorGenerator by np
+    AnchorGenerator by paddle
     """
-    def __init__(self,
-                 base_size=8,
-                 scales=1.0,
-                 ratios=1.0,
-                 scale_major=True,
-                 ctr=None):
+    def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
+        super(S2ANetAnchorGenerator, self).__init__()
         self.base_size = base_size
-        self.scales = scales
-        self.ratios = ratios
+        self.scales = paddle.to_tensor(scales)
+        self.ratios = paddle.to_tensor(ratios)
         self.scale_major = scale_major
         self.ctr = ctr
         self.base_anchors = self.gen_base_anchors()
...@@ -54,7 +49,7 @@ class S2ANetAnchorGenerator(object): ...@@ -54,7 +49,7 @@ class S2ANetAnchorGenerator(object):
else: else:
x_ctr, y_ctr = self.ctr x_ctr, y_ctr = self.ctr
h_ratios = np.sqrt(self.ratios) h_ratios = paddle.sqrt(self.ratios)
w_ratios = 1 / h_ratios w_ratios = 1 / h_ratios
if self.scale_major: if self.scale_major:
ws = (w * w_ratios[:] * self.scales[:]).reshape([-1]) ws = (w * w_ratios[:] * self.scales[:]).reshape([-1])
...@@ -63,61 +58,51 @@ class S2ANetAnchorGenerator(object): ...@@ -63,61 +58,51 @@ class S2ANetAnchorGenerator(object):
ws = (w * self.scales[:] * w_ratios[:]).reshape([-1]) ws = (w * self.scales[:] * w_ratios[:]).reshape([-1])
hs = (h * self.scales[:] * h_ratios[:]).reshape([-1]) hs = (h * self.scales[:] * h_ratios[:]).reshape([-1])
# yapf: disable base_anchors = paddle.stack(
base_anchors = np.stack(
[ [
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
], ],
axis=-1) axis=-1)
base_anchors = np.round(base_anchors) base_anchors = paddle.round(base_anchors)
# yapf: enable
return base_anchors return base_anchors
     def _meshgrid(self, x, y, row_major=True):
-        xx, yy = np.meshgrid(x, y)
-        xx = xx.reshape(-1)
-        yy = yy.reshape(-1)
+        yy, xx = paddle.meshgrid(x, y)
+        yy = yy.reshape([-1])
+        xx = xx.reshape([-1])
         if row_major:
             return xx, yy
         else:
             return yy, xx
-    def grid_anchors(self, featmap_size, stride=16):
+    def forward(self, featmap_size, stride=16):
         # featmap_size*stride project it to original area
         base_anchors = self.base_anchors
-        feat_h, feat_w = featmap_size
-        shift_x = np.arange(0, feat_w, 1, 'int32') * stride
-        shift_y = np.arange(0, feat_h, 1, 'int32') * stride
+        feat_h = featmap_size[0]
+        feat_w = featmap_size[1]
+        shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride
+        shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride
         shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
-        shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
-        # shifts = shifts.type_as(base_anchors)
-        # first feat_w elements correspond to the first row of shifts
-        # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
-        # shifted anchors (K, A, 4), reshape to (K*A, 4)
-        #all_anchors = base_anchors[:, :] + shifts[:, :]
-        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
-        # all_anchors = all_anchors.reshape([-1, 4])
-        # first A rows correspond to A anchors of (0, 0) in feature map,
-        # then (0, 1), (0, 2), ...
+        shifts = paddle.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
+        all_anchors = base_anchors[:, :] + shifts[:, :]
+        all_anchors = all_anchors.reshape([feat_h * feat_w, 4])
         return all_anchors
     def valid_flags(self, featmap_size, valid_size):
         feat_h, feat_w = featmap_size
         valid_h, valid_w = valid_size
         assert valid_h <= feat_h and valid_w <= feat_w
-        valid_x = np.zeros([feat_w], dtype='uint8')
-        valid_y = np.zeros([feat_h], dtype='uint8')
+        valid_x = paddle.zeros([feat_w], dtype='uint8')
+        valid_y = paddle.zeros([feat_h], dtype='uint8')
         valid_x[:valid_w] = 1
         valid_y[:valid_h] = 1
         valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
         valid = valid_xx & valid_yy
-        valid = valid.reshape([-1])
-        # valid = valid[:, None].expand(
-        #     [valid.size(0), self.num_base_anchors]).reshape([-1])
+        valid = valid[:, None].expand(
+            [valid.size(0), self.num_base_anchors]).reshape([-1])
         return valid
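With the generator now an `nn.Layer` whose `forward` builds anchors with paddle ops, anchor generation can run inside the exported graph: base anchors defined at the feature-map origin are shifted to every cell by a meshgrid of stride-spaced offsets. A compact numpy sketch of that shift-and-add, assuming a single base anchor per cell as S2ANet uses (function name is mine):

```python
import numpy as np

def grid_anchors_sketch(base_anchor, featmap_size, stride):
    """Tile one [x1, y1, x2, y2] base anchor over an (H, W) feature map."""
    h, w = featmap_size
    shift_x = np.arange(0, w) * stride          # x offset of every column
    shift_y = np.arange(0, h) * stride          # y offset of every row
    xx, yy = np.meshgrid(shift_x, shift_y)
    shifts = np.stack([xx.ravel(), yy.ravel(), xx.ravel(), yy.ravel()], axis=-1)
    # one anchor per cell -> [H*W, 4]
    return np.asarray(base_anchor, dtype=np.float32) + shifts

anchors = grid_anchors_sketch([-16, -16, 15, 15], featmap_size=(2, 3), stride=8)
print(anchors.shape)   # (6, 4)
print(anchors[4])      # anchor at row 1, col 1: [-8, -8, 23, 23]
```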
...@@ -240,8 +225,8 @@ class S2ANetHead(nn.Layer): ...@@ -240,8 +225,8 @@ class S2ANetHead(nn.Layer):
anchor_strides=[8, 16, 32, 64, 128], anchor_strides=[8, 16, 32, 64, 128],
anchor_scales=[4], anchor_scales=[4],
anchor_ratios=[1.0], anchor_ratios=[1.0],
target_means=(.0, .0, .0, .0, .0), target_means=0.0,
target_stds=(1.0, 1.0, 1.0, 1.0, 1.0), target_stds=1.0,
align_conv_type='AlignConv', align_conv_type='AlignConv',
align_conv_size=3, align_conv_size=3,
use_sigmoid_cls=True, use_sigmoid_cls=True,
...@@ -278,6 +263,8 @@ class S2ANetHead(nn.Layer): ...@@ -278,6 +263,8 @@ class S2ANetHead(nn.Layer):
self.anchor_generators.append( self.anchor_generators.append(
S2ANetAnchorGenerator(anchor_base, anchor_scales, S2ANetAnchorGenerator(anchor_base, anchor_scales,
anchor_ratios)) anchor_ratios))
self.anchor_generators = paddle.nn.LayerList(self.anchor_generators)
self.add_sublayer('s2anet_anchor_gen', self.anchor_generators)
self.fam_cls_convs = nn.Sequential() self.fam_cls_convs = nn.Sequential()
self.fam_reg_convs = nn.Sequential() self.fam_reg_convs = nn.Sequential()
...@@ -412,9 +399,9 @@ class S2ANetHead(nn.Layer): ...@@ -412,9 +399,9 @@ class S2ANetHead(nn.Layer):
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0))) bias_attr=ParamAttr(initializer=Constant(0)))
self.base_anchors = dict() self.featmap_size_list = []
self.featmap_sizes = dict() self.init_anchors_list = []
self.base_anchors = dict() self.rbox_anchors_list = []
self.refine_anchor_list = [] self.refine_anchor_list = []
def forward(self, feats): def forward(self, feats):
...@@ -424,13 +411,27 @@ class S2ANetHead(nn.Layer): ...@@ -424,13 +411,27 @@ class S2ANetHead(nn.Layer):
odm_reg_branch_list = [] odm_reg_branch_list = []
odm_cls_branch_list = [] odm_cls_branch_list = []
self.featmap_sizes = dict() fam_reg1_branch_list = []
self.base_anchors = dict()
self.featmap_size_list = []
self.init_anchors_list = []
self.rbox_anchors_list = []
self.refine_anchor_list = [] self.refine_anchor_list = []
for i, feat in enumerate(feats): for i, feat in enumerate(feats):
fam_cls_feat = self.fam_cls_convs(feat) # prepare anchor
featmap_size = paddle.shape(feat)[-2:]
self.featmap_size_list.append(featmap_size)
init_anchors = self.anchor_generators[i](featmap_size,
self.anchor_strides[i])
init_anchors = paddle.reshape(
init_anchors, [featmap_size[0] * featmap_size[1], 4])
self.init_anchors_list.append(init_anchors)
rbox_anchors = self.rect2rbox(init_anchors)
self.rbox_anchors_list.append(rbox_anchors)
fam_cls_feat = self.fam_cls_convs(feat)
fam_cls = self.fam_cls(fam_cls_feat) fam_cls = self.fam_cls(fam_cls_feat)
# [N, CLS, H, W] --> [N, H, W, CLS] # [N, CLS, H, W] --> [N, H, W, CLS]
fam_cls = fam_cls.transpose([0, 2, 3, 1]) fam_cls = fam_cls.transpose([0, 2, 3, 1])
...@@ -446,29 +447,13 @@ class S2ANetHead(nn.Layer): ...@@ -446,29 +447,13 @@ class S2ANetHead(nn.Layer):
fam_reg_reshape = paddle.reshape(fam_reg, [fam_reg.shape[0], -1, 5]) fam_reg_reshape = paddle.reshape(fam_reg, [fam_reg.shape[0], -1, 5])
fam_reg_branch_list.append(fam_reg_reshape) fam_reg_branch_list.append(fam_reg_reshape)
# prepare anchor # refine anchors
featmap_size = feat.shape[-2:]
self.featmap_sizes[i] = featmap_size
init_anchors = self.anchor_generators[i].grid_anchors(
featmap_size, self.anchor_strides[i])
init_anchors = bbox_utils.rect2rbox(init_anchors)
self.base_anchors[(i, featmap_size[0])] = init_anchors
fam_reg1 = fam_reg.clone() fam_reg1 = fam_reg.clone()
fam_reg1.stop_gradient = True fam_reg1.stop_gradient = True
pd_target_means = paddle.to_tensor( rbox_anchors.stop_gradient = True
np.array( fam_reg1_branch_list.append(fam_reg1)
self.target_means, dtype=np.float32), dtype='float32') refine_anchor = self.bbox_decode(
pd_target_stds = paddle.to_tensor( fam_reg1, rbox_anchors, self.target_stds, self.target_means)
np.array(
self.target_stds, dtype=np.float32), dtype='float32')
pd_init_anchors = paddle.to_tensor(
np.array(
init_anchors, dtype=np.float32), dtype='float32')
refine_anchor = bbox_utils.bbox_decode(
fam_reg1, pd_init_anchors, pd_target_means, pd_target_stds)
self.refine_anchor_list.append(refine_anchor) self.refine_anchor_list.append(refine_anchor)
if self.align_conv_type == 'AlignConv': if self.align_conv_type == 'AlignConv':
...@@ -508,6 +493,87 @@ class S2ANetHead(nn.Layer): ...@@ -508,6 +493,87 @@ class S2ANetHead(nn.Layer):
odm_cls_branch_list, odm_reg_branch_list) odm_cls_branch_list, odm_reg_branch_list)
return self.s2anet_head_out return self.s2anet_head_out
def rect2rbox(self, bboxes):
"""
:param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
:return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle)
"""
num_boxes = paddle.shape(bboxes)[0]
x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
edges1 = paddle.abs(bboxes[:, 2] - bboxes[:, 0])
edges2 = paddle.abs(bboxes[:, 3] - bboxes[:, 1])
rbox_w = paddle.maximum(edges1, edges2)
rbox_h = paddle.minimum(edges1, edges2)
# set angle
inds = edges1 < edges2
inds = paddle.cast(inds, 'int32')
inds1 = inds * paddle.arange(0, num_boxes)
rboxes_angle = inds1 * np.pi / 2.0
rboxes = paddle.stack(
(x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=1)
return rboxes
# deltas to rbox
def delta2rbox(self, rrois, deltas, means, stds, wh_ratio_clip=1e-6):
"""
:param rrois: (cx, cy, w, h, theta)
:param deltas: (dx, dy, dw, dh, dtheta)
:param means: means of anchor
:param stds: stds of anchor
:param wh_ratio_clip: clip threshold of wh_ratio
:return:
"""
deltas = paddle.reshape(deltas, [-1, 5])
rrois = paddle.reshape(rrois, [-1, 5])
pd_means = paddle.ones(shape=[5]) * means
pd_stds = paddle.ones(shape=[5]) * stds
denorm_deltas = deltas * pd_stds + pd_means
dx = denorm_deltas[:, 0]
dy = denorm_deltas[:, 1]
dw = denorm_deltas[:, 2]
dh = denorm_deltas[:, 3]
dangle = denorm_deltas[:, 4]
max_ratio = np.abs(np.log(wh_ratio_clip))
dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
rroi_x = rrois[:, 0]
rroi_y = rrois[:, 1]
rroi_w = rrois[:, 2]
rroi_h = rrois[:, 3]
rroi_angle = rrois[:, 4]
gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin(
rroi_angle) + rroi_x
gy = dx * rroi_w * paddle.sin(rroi_angle) + dy * rroi_h * paddle.cos(
rroi_angle) + rroi_y
gw = rroi_w * dw.exp()
gh = rroi_h * dh.exp()
ga = np.pi * dangle + rroi_angle
ga = (ga + np.pi / 4) % np.pi - np.pi / 4
bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
return bboxes
def bbox_decode(self, bbox_preds, anchors, stds, means, wh_ratio_clip=1e-6):
"""decode bbox from deltas
Args:
bbox_preds: bbox_preds, shape=[N,H,W,5]
anchors: anchors, shape=[H,W,5]
return:
bboxes: return decoded bboxes, shape=[N*H*W,5]
"""
num_imgs, H, W, _ = bbox_preds.shape
bbox_delta = paddle.reshape(bbox_preds, [-1, 5])
bboxes = self.delta2rbox(anchors, bbox_delta, means, stds,
wh_ratio_clip)
return bboxes
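`bbox_decode` flattens the FAM regression output to [-1, 5] and applies `delta2rbox`, which is the inverse of the `rbox2delta` encoding used when building training targets (see `RBoxAssigner.rbox2delta` further down): (dx, dy) are expressed in the anchor's rotated frame, (dw, dh) are log scales, and the angle delta is normalized into [-pi/4, 3pi/4) and divided by pi. A small numpy roundtrip, assuming unit stds and zero means, just to show the two directions are consistent; these helpers are illustrative rewrites, not the ppdet functions themselves.

```python
import numpy as np

def rbox2delta_np(anchor, gt):
    """Encode a gt rotated box against an anchor (both [cx, cy, w, h, angle])."""
    ax, ay, aw, ah, aa = anchor
    dx = (np.cos(aa) * (gt[0] - ax) + np.sin(aa) * (gt[1] - ay)) / aw
    dy = (-np.sin(aa) * (gt[0] - ax) + np.cos(aa) * (gt[1] - ay)) / ah
    dw = np.log(gt[2] / aw)
    dh = np.log(gt[3] / ah)
    da = ((gt[4] - aa + np.pi / 4) % np.pi - np.pi / 4) / np.pi
    return np.array([dx, dy, dw, dh, da])

def delta2rbox_np(anchor, delta):
    """Decode deltas back into an absolute rotated box."""
    ax, ay, aw, ah, aa = anchor
    dx, dy, dw, dh, da = delta
    gx = dx * aw * np.cos(aa) - dy * ah * np.sin(aa) + ax
    gy = dx * aw * np.sin(aa) + dy * ah * np.cos(aa) + ay
    ga = np.pi * da + aa
    return np.array([gx, gy, aw * np.exp(dw), ah * np.exp(dh), ga])

anchor = np.array([50.0, 40.0, 32.0, 16.0, 0.3])
gt = np.array([55.0, 38.0, 40.0, 20.0, 0.5])
print(np.allclose(delta2rbox_np(anchor, rbox2delta_np(anchor, gt)), gt))  # True
```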
def get_prediction(self, nms_pre): def get_prediction(self, nms_pre):
refine_anchors = self.refine_anchor_list refine_anchors = self.refine_anchor_list
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = self.s2anet_head_out fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = self.s2anet_head_out
...@@ -518,6 +584,7 @@ class S2ANetHead(nn.Layer): ...@@ -518,6 +584,7 @@ class S2ANetHead(nn.Layer):
nms_pre, nms_pre,
cls_out_channels=self.cls_out_channels, cls_out_channels=self.cls_out_channels,
use_sigmoid_cls=self.use_sigmoid_cls) use_sigmoid_cls=self.use_sigmoid_cls)
return pred_scores, pred_bboxes return pred_scores, pred_bboxes
def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0): def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0):
...@@ -536,41 +603,25 @@ class S2ANetHead(nn.Layer): ...@@ -536,41 +603,25 @@ class S2ANetHead(nn.Layer):
return loss return loss
def get_fam_loss(self, fam_target, s2anet_head_out): def get_fam_loss(self, fam_target, s2anet_head_out):
(labels, label_weights, bbox_targets, bbox_weights, pos_inds, (feat_labels, feat_label_weights, feat_bbox_targets, feat_bbox_weights,
neg_inds) = fam_target pos_inds, neg_inds) = fam_target
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out fam_cls_score, fam_bbox_pred = s2anet_head_out
fam_cls_losses = [] # step1: sample count
fam_bbox_losses = []
st_idx = 0
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
num_total_samples = len(pos_inds) + len( num_total_samples = len(pos_inds) + len(
neg_inds) if self.sampling else len(pos_inds) neg_inds) if self.sampling else len(pos_inds)
num_total_samples = max(1, num_total_samples) num_total_samples = max(1, num_total_samples)
for idx, feat_size in enumerate(featmap_sizes):
feat_anchor_num = feat_size[0] * feat_size[1]
# step1: get data
feat_labels = labels[st_idx:st_idx + feat_anchor_num]
feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
st_idx += feat_anchor_num
# step2: calc cls loss # step2: calc cls loss
feat_labels = feat_labels.reshape(-1) feat_labels = feat_labels.reshape(-1)
feat_label_weights = feat_label_weights.reshape(-1) feat_label_weights = feat_label_weights.reshape(-1)
fam_cls_score = fam_cls_branch_list[idx]
fam_cls_score = paddle.squeeze(fam_cls_score, axis=0) fam_cls_score = paddle.squeeze(fam_cls_score, axis=0)
fam_cls_score1 = fam_cls_score fam_cls_score1 = fam_cls_score
# gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1 # gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1
feat_labels = feat_labels + 1
feat_labels = paddle.to_tensor(feat_labels) feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = paddle.nn.functional.one_hot( feat_labels_one_hot = F.one_hot(feat_labels, self.cls_out_channels + 1)
feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:] feat_labels_one_hot = feat_labels_one_hot[:, 1:]
feat_labels_one_hot.stop_gradient = True feat_labels_one_hot.stop_gradient = True
...@@ -592,15 +643,11 @@ class S2ANetHead(nn.Layer): ...@@ -592,15 +643,11 @@ class S2ANetHead(nn.Layer):
fam_cls = fam_cls * feat_label_weights fam_cls = fam_cls * feat_label_weights
fam_cls_total = paddle.sum(fam_cls) fam_cls_total = paddle.sum(fam_cls)
fam_cls_losses.append(fam_cls_total)
# step3: regression loss # step3: regression loss
fam_bbox_pred = fam_reg_branch_list[idx]
feat_bbox_targets = paddle.to_tensor( feat_bbox_targets = paddle.to_tensor(
feat_bbox_targets, dtype='float32', stop_gradient=True) feat_bbox_targets, dtype='float32', stop_gradient=True)
feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5]) feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
fam_bbox_pred = fam_reg_branch_list[idx]
fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0) fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0)
fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5]) fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets) fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
...@@ -612,55 +659,39 @@ class S2ANetHead(nn.Layer): ...@@ -612,55 +659,39 @@ class S2ANetHead(nn.Layer):
fam_bbox = fam_bbox * feat_bbox_weights fam_bbox = fam_bbox * feat_bbox_weights
fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
fam_bbox_losses.append(fam_bbox_total)
fam_cls_loss = paddle.add_n(fam_cls_losses)
fam_cls_loss_weight = paddle.to_tensor( fam_cls_loss_weight = paddle.to_tensor(
self.cls_loss_weight[0], dtype='float32', stop_gradient=True) self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
fam_cls_loss = fam_cls_loss * fam_cls_loss_weight fam_cls_loss = fam_cls_total * fam_cls_loss_weight
fam_reg_loss = paddle.add_n(fam_bbox_losses) fam_reg_loss = paddle.add_n(fam_bbox_total)
return fam_cls_loss, fam_reg_loss return fam_cls_loss, fam_reg_loss
def get_odm_loss(self, odm_target, s2anet_head_out): def get_odm_loss(self, odm_target, s2anet_head_out):
(labels, label_weights, bbox_targets, bbox_weights, pos_inds, (feat_labels, feat_label_weights, feat_bbox_targets, feat_bbox_weights,
neg_inds) = odm_target pos_inds, neg_inds) = odm_target
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out odm_cls_score, odm_bbox_pred = s2anet_head_out
odm_cls_losses = [] # step1: sample count
odm_bbox_losses = []
st_idx = 0
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
num_total_samples = len(pos_inds) + len( num_total_samples = len(pos_inds) + len(
neg_inds) if self.sampling else len(pos_inds) neg_inds) if self.sampling else len(pos_inds)
num_total_samples = max(1, num_total_samples) num_total_samples = max(1, num_total_samples)
for idx, feat_size in enumerate(featmap_sizes):
feat_anchor_num = feat_size[0] * feat_size[1]
# step1: get data
feat_labels = labels[st_idx:st_idx + feat_anchor_num]
feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
st_idx += feat_anchor_num
# step2: calc cls loss # step2: calc cls loss
feat_labels = feat_labels.reshape(-1) feat_labels = feat_labels.reshape(-1)
feat_label_weights = feat_label_weights.reshape(-1) feat_label_weights = feat_label_weights.reshape(-1)
odm_cls_score = odm_cls_branch_list[idx]
odm_cls_score = paddle.squeeze(odm_cls_score, axis=0) odm_cls_score = paddle.squeeze(odm_cls_score, axis=0)
odm_cls_score1 = odm_cls_score odm_cls_score1 = odm_cls_score
# gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1 # gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1
# for debug 0426
feat_labels = feat_labels + 1
feat_labels = paddle.to_tensor(feat_labels) feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = paddle.nn.functional.one_hot( feat_labels_one_hot = F.one_hot(feat_labels, self.cls_out_channels + 1)
feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:] feat_labels_one_hot = feat_labels_one_hot[:, 1:]
feat_labels_one_hot.stop_gradient = True feat_labels_one_hot.stop_gradient = True
num_total_samples = paddle.to_tensor( num_total_samples = paddle.to_tensor(
num_total_samples, dtype='float32', stop_gradient=True) num_total_samples, dtype='float32', stop_gradient=True)
odm_cls = F.sigmoid_focal_loss( odm_cls = F.sigmoid_focal_loss(
odm_cls_score1, odm_cls_score1,
feat_labels_one_hot, feat_labels_one_hot,
...@@ -671,20 +702,16 @@ class S2ANetHead(nn.Layer): ...@@ -671,20 +702,16 @@ class S2ANetHead(nn.Layer):
feat_label_weights.shape[0], 1) feat_label_weights.shape[0], 1)
feat_label_weights = np.repeat( feat_label_weights = np.repeat(
feat_label_weights, self.cls_out_channels, axis=1) feat_label_weights, self.cls_out_channels, axis=1)
feat_label_weights = paddle.to_tensor(feat_label_weights) feat_label_weights = paddle.to_tensor(
feat_label_weights.stop_gradient = True feat_label_weights, stop_gradient=True)
odm_cls = odm_cls * feat_label_weights odm_cls = odm_cls * feat_label_weights
odm_cls_total = paddle.sum(odm_cls) odm_cls_total = paddle.sum(odm_cls)
odm_cls_losses.append(odm_cls_total)
# # step3: regression loss # step3: regression loss
feat_bbox_targets = paddle.to_tensor( feat_bbox_targets = paddle.to_tensor(
feat_bbox_targets, dtype='float32') feat_bbox_targets, dtype='float32', stop_gradient=True)
feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5]) feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
feat_bbox_targets.stop_gradient = True
odm_bbox_pred = odm_reg_branch_list[idx]
odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0) odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0)
odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5]) odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets) odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
...@@ -695,13 +722,11 @@ class S2ANetHead(nn.Layer): ...@@ -695,13 +722,11 @@ class S2ANetHead(nn.Layer):
feat_bbox_weights, stop_gradient=True) feat_bbox_weights, stop_gradient=True)
odm_bbox = odm_bbox * feat_bbox_weights odm_bbox = odm_bbox * feat_bbox_weights
odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
odm_bbox_losses.append(odm_bbox_total)
odm_cls_loss = paddle.add_n(odm_cls_losses)
odm_cls_loss_weight = paddle.to_tensor( odm_cls_loss_weight = paddle.to_tensor(
self.cls_loss_weight[1], dtype='float32', stop_gradient=True) self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
odm_cls_loss = odm_cls_loss * odm_cls_loss_weight odm_cls_loss = odm_cls_total * odm_cls_loss_weight
odm_reg_loss = paddle.add_n(odm_bbox_losses) odm_reg_loss = paddle.add_n(odm_bbox_total)
return odm_cls_loss, odm_reg_loss return odm_cls_loss, odm_reg_loss
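In both the FAM and ODM branches the class targets are one-hot encoded over `cls_out_channels + 1` classes and the first column is then dropped, so anchors labelled 0 (background/negatives) contribute an all-zero target row to the sigmoid focal loss. A tiny framework-agnostic sketch of that encoding, just to show the shape logic (numpy here; the diff uses `F.one_hot`):

```python
import numpy as np

def one_hot_drop_background(labels, num_classes):
    """labels: 0 = background, 1..num_classes = foreground class ids."""
    eye = np.eye(num_classes + 1, dtype=np.float32)   # (C+1, C+1)
    one_hot = eye[labels]                             # (N, C+1)
    return one_hot[:, 1:]                             # drop background column -> (N, C)

labels = np.array([0, 3, 1])          # one negative anchor, two positives
print(one_hot_drop_background(labels, num_classes=15)[:, :4])
# row 0 is all zeros; rows 1 and 2 have a single 1 at class index 2 and 0
```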
def get_loss(self, inputs): def get_loss(self, inputs):
...@@ -723,46 +748,38 @@ class S2ANetHead(nn.Layer): ...@@ -723,46 +748,38 @@ class S2ANetHead(nn.Layer):
is_crowd = inputs['is_crowd'][im_id].numpy() is_crowd = inputs['is_crowd'][im_id].numpy()
gt_labels = gt_labels + 1 gt_labels = gt_labels + 1
# featmap_sizes
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
anchors_list, valid_flag_list = self.get_init_anchors(featmap_sizes,
np_im_shape)
anchors_list_all = []
for ii, anchor in enumerate(anchors_list):
anchor = anchor.reshape(-1, 4)
anchor = bbox_utils.rect2rbox(anchor)
anchors_list_all.extend(anchor)
anchors_list_all = np.array(anchors_list_all)
# get im_feat
fam_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[0]]
fam_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[1]]
odm_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[2]]
odm_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[3]]
im_s2anet_head_out = (fam_cls_feats_list, fam_reg_feats_list,
odm_cls_feats_list, odm_reg_feats_list)
# FAM # FAM
im_fam_target = self.anchor_assign(anchors_list_all, gt_bboxes, for idx, rbox_anchors in enumerate(self.rbox_anchors_list):
rbox_anchors = rbox_anchors.numpy()
rbox_anchors = rbox_anchors.reshape(-1, 5)
im_fam_target = self.anchor_assign(rbox_anchors, gt_bboxes,
gt_labels, is_crowd) gt_labels, is_crowd)
if im_fam_target is not None: # feat
fam_cls_feat = self.s2anet_head_out[0][idx][im_id]
fam_reg_feat = self.s2anet_head_out[1][idx][im_id]
im_s2anet_fam_feat = (fam_cls_feat, fam_reg_feat)
im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss( im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss(
im_fam_target, im_s2anet_head_out) im_fam_target, im_s2anet_fam_feat)
fam_cls_loss_lst.append(im_fam_cls_loss) fam_cls_loss_lst.append(im_fam_cls_loss)
fam_reg_loss_lst.append(im_fam_reg_loss) fam_reg_loss_lst.append(im_fam_reg_loss)
# ODM # ODM
refine_anchors_list, valid_flag_list = self.get_refine_anchors( for idx, refine_anchors in enumerate(self.refine_anchor_list):
featmap_sizes, image_shape=np_im_shape) refine_anchors = refine_anchors.numpy()
refine_anchors_list = np.array(refine_anchors_list) refine_anchors = refine_anchors.reshape(-1, 5)
im_odm_target = self.anchor_assign(refine_anchors_list, gt_bboxes, im_odm_target = self.anchor_assign(refine_anchors, gt_bboxes,
gt_labels, is_crowd) gt_labels, is_crowd)
if im_odm_target is not None: odm_cls_feat = self.s2anet_head_out[2][idx][im_id]
odm_reg_feat = self.s2anet_head_out[3][idx][im_id]
im_s2anet_odm_feat = (odm_cls_feat, odm_reg_feat)
im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss( im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss(
im_odm_target, im_s2anet_head_out) im_odm_target, im_s2anet_odm_feat)
odm_cls_loss_lst.append(im_odm_cls_loss) odm_cls_loss_lst.append(im_odm_cls_loss)
odm_reg_loss_lst.append(im_odm_reg_loss) odm_reg_loss_lst.append(im_odm_reg_loss)
fam_cls_loss = paddle.add_n(fam_cls_loss_lst) fam_cls_loss = paddle.add_n(fam_cls_loss_lst)
fam_reg_loss = paddle.add_n(fam_reg_loss_lst) fam_reg_loss = paddle.add_n(fam_reg_loss_lst)
odm_cls_loss = paddle.add_n(odm_cls_loss_lst) odm_cls_loss = paddle.add_n(odm_cls_loss_lst)
...@@ -774,65 +791,6 @@ class S2ANetHead(nn.Layer): ...@@ -774,65 +791,6 @@ class S2ANetHead(nn.Layer):
'odm_reg_loss': odm_reg_loss 'odm_reg_loss': odm_reg_loss
} }
def get_init_anchors(self, featmap_sizes, image_shape):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
image_shape (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
anchor_list = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
anchor_list.append(anchors)
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = image_shape
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
valid_flag_list.append(flags)
return anchor_list, valid_flag_list
def get_refine_anchors(self, featmap_sizes, image_shape):
num_levels = len(featmap_sizes)
refine_anchors_list = []
for i in range(num_levels):
refine_anchor = self.refine_anchor_list[i]
refine_anchor = paddle.squeeze(refine_anchor, axis=0)
refine_anchor = refine_anchor.numpy()
refine_anchor = np.reshape(refine_anchor,
[-1, refine_anchor.shape[-1]])
refine_anchors_list.extend(refine_anchor)
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = image_shape
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
valid_flag_list.append(flags)
return refine_anchors_list, valid_flag_list
def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre, def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre,
cls_out_channels, use_sigmoid_cls): cls_out_channels, use_sigmoid_cls):
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
...@@ -866,14 +824,8 @@ class S2ANetHead(nn.Layer): ...@@ -866,14 +824,8 @@ class S2ANetHead(nn.Layer):
bbox_pred = paddle.gather(bbox_pred, topk_inds) bbox_pred = paddle.gather(bbox_pred, topk_inds)
scores = paddle.gather(scores, topk_inds) scores = paddle.gather(scores, topk_inds)
pd_target_means = paddle.to_tensor( bboxes = self.delta2rbox(anchors, bbox_pred, self.target_means,
np.array( self.target_stds)
self.target_means, dtype=np.float32), dtype='float32')
pd_target_stds = paddle.to_tensor(
np.array(
self.target_stds, dtype=np.float32), dtype='float32')
bboxes = bbox_utils.delta2rbox(anchors, bbox_pred, pd_target_means,
pd_target_stds)
mlvl_bboxes.append(bboxes) mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores) mlvl_scores.append(scores)
......
@@ -17,7 +17,7 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from ppdet.core.workspace import register
-from ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly, pd_rbox2poly
+from ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly, rbox2poly
 try:
     from collections.abc import Sequence
 except Exception:
@@ -214,7 +214,7 @@ class FCOSPostProcess(object):
 @register
-class S2ANetBBoxPostProcess(object):
+class S2ANetBBoxPostProcess(nn.Layer):
     __shared__ = ['num_classes']
     __inject__ = ['nms']
@@ -225,41 +225,43 @@ class S2ANetBBoxPostProcess(object):
         self.min_bbox_size = min_bbox_size
         self.nms = nms
         self.origin_shape_list = []
+        self.fake_pred_cls_score_bbox = paddle.to_tensor(
+            np.array(
+                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+                dtype='float32'))
+        self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
-    def __call__(self, pred_scores, pred_bboxes):
+    def forward(self, pred_scores, pred_bboxes):
         """
         pred_scores : [N, M] score
         pred_bboxes : [N, 5] xc, yc, w, h, a
         im_shape : [N, 2] im_shape
         scale_factor : [N, 2] scale_factor
         """
-        pred_ploys = pd_rbox2poly(pred_bboxes)
-        pred_ploys = paddle.reshape(
-            pred_ploys, [1, pred_ploys.shape[0], pred_ploys.shape[1]])
-        pred_scores = paddle.to_tensor(pred_scores)
+        pred_ploys0 = rbox2poly(pred_bboxes)
+        pred_ploys = paddle.unsqueeze(pred_ploys0, axis=0)
         # pred_scores [NA, 16] --> [16, NA]
-        pred_scores = paddle.transpose(pred_scores, [1, 0])
-        pred_scores = paddle.reshape(
-            pred_scores, [1, pred_scores.shape[0], pred_scores.shape[1]])
+        pred_scores0 = paddle.transpose(pred_scores, [1, 0])
+        pred_scores = paddle.unsqueeze(pred_scores0, axis=0)
         pred_cls_score_bbox, bbox_num, _ = self.nms(pred_ploys, pred_scores,
                                                     self.num_classes)
         # Prevent empty bbox_pred from decode or NMS.
         # Bboxes and score before NMS may be empty due to the score threshold.
-        if pred_cls_score_bbox.shape[0] == 0:
-            pred_cls_score_bbox = paddle.to_tensor(
-                np.array(
-                    [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
-            bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
+        if pred_cls_score_bbox.shape[0] <= 0 or pred_cls_score_bbox.shape[
+                1] <= 1:
+            pred_cls_score_bbox = self.fake_pred_cls_score_bbox
+            bbox_num = self.fake_bbox_num
+        pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10])
+        assert pred_cls_score_bbox.shape[1] == 10
         return pred_cls_score_bbox, bbox_num
     def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
         """
         Rescale, clip and filter the bbox from the output of NMS to
         get final prediction.
         Args:
             bboxes(Tensor): bboxes [N, 10]
             bbox_num(Tensor): bbox_num
@@ -270,6 +272,7 @@ class S2ANetBBoxPostProcess(object):
                 including labels, scores and bboxes. The size of
                 bboxes are corresponding to the original image.
         """
+        assert bboxes.shape[1] == 10
         origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
         origin_shape_list = []
@@ -292,7 +295,7 @@ class S2ANetBBoxPostProcess(object):
         # bboxes: [N, 10], label, score, bbox
         pred_label_score = bboxes[:, 0:2]
-        pred_bbox = bboxes[:, 2:10:1]
+        pred_bbox = bboxes[:, 2:]
         # rescale bbox to original image
         scaled_bbox = pred_bbox / scale_factor_list
......
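After NMS the post-process always returns rows of 10 values, [label, score, x0, y0, ..., x3, y3], falling back to a single fake row when nothing survives the score threshold, and `get_pred` divides the eight polygon coordinates by the per-image scale factors (expanded to eight entries) to map them back to the original image size. A short numpy illustration of that rescaling step; shapes follow the docstring above, names are mine:

```python
import numpy as np

def rescale_pred(pred_cls_score_bbox, scale_w, scale_h):
    """pred_cls_score_bbox: [N, 10] rows of (label, score, 8 poly coords)."""
    label_score = pred_cls_score_bbox[:, :2]
    poly = pred_cls_score_bbox[:, 2:]
    scale = np.tile([scale_w, scale_h], 4)        # [w, h, w, h, w, h, w, h]
    poly = poly / scale                           # back to the original image scale
    return np.concatenate([label_score, poly], axis=1)

pred = np.array([[2, 0.87, 20, 10, 60, 10, 60, 40, 20, 40]], dtype=np.float32)
print(rescale_pred(pred, scale_w=2.0, scale_h=2.0))
# -> [[ 2.    0.87 10.    5.   30.    5.   30.   20.   10.   20. ]]
```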
@@ -16,7 +16,6 @@ import paddle
 from ppdet.core.workspace import register, serializable
 from .target import rpn_anchor_target, generate_proposal_target, generate_mask_target, libra_generate_proposal_target
-from ppdet.modeling import bbox_utils
 import numpy as np
@@ -283,13 +282,58 @@ class RBoxAssigner(object):
         """
         if anchors.ndim == 3:
-            anchors = anchors.reshape(-1, anchor.shape[-1])
+            anchors = anchors.reshape(-1, anchors.shape[-1])
         assert anchors.ndim == 2
         anchor_num = anchors.shape[0]
         anchor_valid = np.ones((anchor_num), np.uint8)
         anchor_inds = np.arange(anchor_num)
         return anchor_inds
def rbox2delta(self,
proposals,
gt,
means=[0, 0, 0, 0, 0],
stds=[1, 1, 1, 1, 1]):
"""
Args:
proposals: tensor [N, 5]
gt: gt [N, 5]
means: means [5]
stds: stds [5]
Returns:
"""
proposals = proposals.astype(np.float64)
PI = np.pi
gt_widths = gt[..., 2]
gt_heights = gt[..., 3]
gt_angle = gt[..., 4]
proposals_widths = proposals[..., 2]
proposals_heights = proposals[..., 3]
proposals_angle = proposals[..., 4]
coord = gt[..., 0:2] - proposals[..., 0:2]
dx = (np.cos(proposals[..., 4]) * coord[..., 0] +
np.sin(proposals[..., 4]) * coord[..., 1]) / proposals_widths
dy = (-np.sin(proposals[..., 4]) * coord[..., 0] +
np.cos(proposals[..., 4]) * coord[..., 1]) / proposals_heights
dw = np.log(gt_widths / proposals_widths)
dh = np.log(gt_heights / proposals_heights)
da = (gt_angle - proposals_angle)
da = (da + PI / 4) % PI - PI / 4
da /= PI
deltas = np.stack([dx, dy, dw, dh, da], axis=-1)
means = np.array(means, dtype=deltas.dtype)
stds = np.array(stds, dtype=deltas.dtype)
deltas = (deltas - means) / stds
deltas = deltas.astype(np.float32)
return deltas
     def assign_anchor(self,
                       anchors,
                       gt_bboxes,
@@ -405,7 +449,7 @@ class RBoxAssigner(object):
         #print('ancho target pos_inds', pos_inds, len(pos_inds))
         pos_sampled_gt_boxes = gt_bboxes[anchor_gt_bbox_inds[pos_inds]]
         if len(pos_inds) > 0:
-            pos_bbox_targets = bbox_utils.rbox2delta(pos_sampled_anchors,
+            pos_bbox_targets = self.rbox2delta(pos_sampled_anchors,
                                                pos_sampled_gt_boxes)
             bbox_targets[pos_inds, :] = pos_bbox_targets
             bbox_weights[pos_inds, :] = 1.0
......