diff --git a/dataset/dota_coco/dota_generate_test_result.py b/dataset/dota_coco/dota_generate_test_result.py new file mode 100644 index 0000000000000000000000000000000000000000..00569970a8b80e5467cc9ded2f0cf09253e786c1 --- /dev/null +++ b/dataset/dota_coco/dota_generate_test_result.py @@ -0,0 +1,255 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import os.path as osp +import re +import json +import glob +import cv2 +import numpy as np +from multiprocessing import Pool +from functools import partial +from shapely.geometry import Polygon +import argparse + +nms_thresh = 0.1 + +class_name_15 = [ + 'plane', 'baseball-diamond', 'bridge', 'ground-track-field', + 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', + 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', + 'harbor', 'swimming-pool', 'helicopter' +] + +class_name_16 = [ + 'plane', 'baseball-diamond', 'bridge', 'ground-track-field', + 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', + 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', + 'harbor', 'swimming-pool', 'helicopter', 'container-crane' +] + + +def rbox_iou(g, p): + """ + iou of rbox + """ + g = np.array(g) + p = np.array(p) + g = Polygon(g[:8].reshape((4, 2))) + p = Polygon(p[:8].reshape((4, 2))) + g = g.buffer(0) + p = p.buffer(0) + if not g.is_valid or not p.is_valid: + return 0 + inter = Polygon(g).intersection(Polygon(p)).area + union = g.area + p.area - inter + if union == 0: + return 0 + else: + return inter / union + + +def py_cpu_nms_poly_fast(dets, thresh): + """ + Args: + dets: pred results + thresh: nms threshold + + Returns: index of keep + """ + obbs = dets[:, 0:-1] + x1 = np.min(obbs[:, 0::2], axis=1) + y1 = np.min(obbs[:, 1::2], axis=1) + x2 = np.max(obbs[:, 0::2], axis=1) + y2 = np.max(obbs[:, 1::2], axis=1) + scores = dets[:, 8] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + + polys = [] + for i in range(len(dets)): + tm_polygon = [dets[i][0], dets[i][1], + dets[i][2], dets[i][3], + dets[i][4], dets[i][5], + dets[i][6], dets[i][7]] + polys.append(tm_polygon) + polys = np.array(polys) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + ovr = [] + i = order[0] + keep.append(i) + + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + hbb_inter = w * h + hbb_ovr = hbb_inter / (areas[i] + areas[order[1:]] - hbb_inter) + # h_keep_inds = np.where(hbb_ovr == 0)[0] + h_inds = np.where(hbb_ovr > 0)[0] + tmp_order = order[h_inds + 1] + for j in range(tmp_order.size): + iou = rbox_iou(polys[i], polys[tmp_order[j]]) + hbb_ovr[h_inds[j]] = iou + # ovr.append(iou) + # ovr_index.append(tmp_order[j]) + + try: + if math.isnan(ovr[0]): + pdb.set_trace() + except: + pass + inds = np.where(hbb_ovr <= thresh)[0] + + order = order[inds + 1] + return keep + + +def poly2origpoly(poly, x, y, rate): + origpoly = [] + for i in range(int(len(poly)/2)): + tmp_x = float(poly[i * 2] + x) / float(rate) + tmp_y = float(poly[i * 2 + 1] + y) / float(rate) + origpoly.append(tmp_x) + origpoly.append(tmp_y) + return origpoly + + +def nmsbynamedict(nameboxdict, nms, thresh): + """ + Args: + nameboxdict: nameboxdict + nms: nms + thresh: nms threshold + + Returns: nms result as dict + """ + nameboxnmsdict = {x: [] for x in nameboxdict} + for imgname in nameboxdict: + keep = nms(np.array(nameboxdict[imgname]), thresh) + outdets = [] + for index in keep: + outdets.append(nameboxdict[imgname][index]) + nameboxnmsdict[imgname] = outdets + return nameboxnmsdict + + +def merge_single(output_dir, nms, pred_class_lst): + """ + Args: + output_dir: output_dir + nms: nms + pred_class_lst: pred_class_lst + class_name: class_name + + Returns: + + """ + class_name, pred_bbox_list = pred_class_lst + nameboxdict = {} + for line in pred_bbox_list: + splitline = line.split(' ') + subname = splitline[0] + splitname = subname.split('__') + oriname = splitname[0] + pattern1 = re.compile(r'__\d+___\d+') + x_y = re.findall(pattern1, subname) + x_y_2 = re.findall(r'\d+', x_y[0]) + x, y = int(x_y_2[0]), int(x_y_2[1]) + + pattern2 = re.compile(r'__([\d+\.]+)__\d+___') + + rate = re.findall(pattern2, subname)[0] + + confidence = splitline[1] + poly = list(map(float, splitline[2:])) + origpoly = poly2origpoly(poly, x, y, rate) + det = origpoly + det.append(confidence) + det = list(map(float, det)) + if (oriname not in nameboxdict): + nameboxdict[oriname] = [] + nameboxdict[oriname].append(det) + nameboxnmsdict = nmsbynamedict(nameboxdict, nms, nms_thresh) + + # write result + dstname = os.path.join(output_dir, class_name + '.txt') + with open(dstname, 'w') as f_out: + for imgname in nameboxnmsdict: + for det in nameboxnmsdict[imgname]: + confidence = det[-1] + bbox = det[0:-1] + outline = imgname + ' ' + str(confidence) + ' ' + ' '.join(map(str, bbox)) + f_out.write(outline + '\n') + + +def dota_generate_test_result(pred_txt_dir, output_dir='output', dota_version='v1.0'): + """ + pred_txt_dir: dir of pred txt + output_dir: dir of output + dota_version: dota_version v1.0 or v1.5 or v2.0 + """ + pred_txt_list = glob.glob("{}/*.txt".format(pred_txt_dir)) + + # step1: summary pred bbox + pred_classes = {} + class_lst = class_name_15 if dota_version == 'v1.0' else class_name_16 + for class_name in class_lst: + pred_classes[class_name] = [] + + for current_txt in pred_txt_list: + img_id = os.path.split(current_txt)[1] + img_id = img_id.split('.txt')[0] + with open(current_txt) as f: + res = f.readlines() + for item in res: + item = item.split(' ') + pred_class = item[0] + item[0] = img_id + pred_bbox = ' '.join(item) + pred_classes[pred_class].append(pred_bbox) + + pred_classes_lst = [] + for class_name in pred_classes.keys(): + print('class_name: {}, count: {}'.format(class_name, len(pred_classes[class_name]))) + pred_classes_lst.append((class_name, pred_classes[class_name])) + + # step2: merge + pool = Pool(len(class_lst)) + nms = py_cpu_nms_poly_fast + mergesingle_fn = partial(merge_single, output_dir, nms) + pool.map(mergesingle_fn, pred_classes_lst) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='dota anno to coco') + parser.add_argument('--pred_txt_dir', help='path of pred txt dir') + parser.add_argument('--output_dir', help='path of output dir', default='output') + parser.add_argument( + '--dota_version', + help='dota_version, v1.0 or v1.5 or v2.0', + type=str, + default='v1.0') + + args = parser.parse_args() + + # process + dota_generate_test_result(args.pred_txt_dir, args.output_dir, args.dota_version) + print('done!') diff --git a/dataset/dota_coco/dota_to_coco.py b/dataset/dota_coco/dota_to_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..dbeefc58a8d88908ed71bdb443e8451539d4458a --- /dev/null +++ b/dataset/dota_coco/dota_to_coco.py @@ -0,0 +1,166 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os.path as osp +import json +import glob +import cv2 +import numpy as np +from PIL import Image +import logging +import argparse + +# add python path of PadleDetection to sys.path +parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.modeling.bbox_utils import poly_to_rbox +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + +class_name_15 = [ + 'plane', 'baseball-diamond', 'bridge', 'ground-track-field', + 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', + 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', + 'harbor', 'swimming-pool', 'helicopter' +] + +class_name_16 = [ + 'plane', 'baseball-diamond', 'bridge', 'ground-track-field', + 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', + 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', + 'harbor', 'swimming-pool', 'helicopter', 'container-crane' +] + + +def dota_2_coco(image_dir, + txt_dir, + json_path='dota_coco.json', + is_obb=True, + dota_version='v1.0'): + """ + image_dir: image dir + txt_dir: txt label dir + json_path: json save path + is_obb: is obb or not + dota_version: dota_version v1.0 or v1.5 or v2.0 + """ + + img_lists = glob.glob("{}/*.png".format(image_dir)) + data_dict = {} + data_dict['images'] = [] + data_dict['categories'] = [] + data_dict['annotations'] = [] + inst_count = 0 + + # categories + class_name2id = {} + if dota_version == 'v1.0': + for class_id, class_name in enumerate(class_name_15): + class_name2id[class_name] = class_id + 1 + single_cat = { + 'id': class_id + 1, + 'name': class_name, + 'supercategory': class_name + } + data_dict['categories'].append(single_cat) + + for image_id, img_path in enumerate(img_lists): + single_image = {} + basename = osp.basename(img_path) + single_image['file_name'] = basename + single_image['id'] = image_id + img = cv2.imread(img_path) + height, width, _ = img.shape + single_image['width'] = width + single_image['height'] = height + # add image + data_dict['images'].append(single_image) + + # annotations + anno_txt_path = osp.join(txt_dir, osp.splitext(basename)[0] + '.txt') + if not osp.exists(anno_txt_path): + logger.warn('path of {} not exists'.format(anno_txt_path)) + + for line in open(anno_txt_path): + line = line.strip() + # skip + if line.find('imagesource') >= 0 or line.find('gsd') >= 0: + continue + + # x1,y1,x2,y2,x3,y3,x4,y4 class_name, is_different + single_obj_anno = line.split(' ') + assert len(single_obj_anno) == 10 + single_obj_poly = [float(e) for e in single_obj_anno[0:8]] + single_obj_classname = single_obj_anno[8] + single_obj_different = int(single_obj_anno[9]) + + single_obj = {} + + single_obj['category_id'] = class_name2id[single_obj_classname] + single_obj['segmentation'] = [] + single_obj['segmentation'].append(single_obj_poly) + single_obj['iscrowd'] = 0 + + # rbox or bbox + if is_obb: + polys = [single_obj_poly] + rboxs = poly_to_rbox(polys) + rbox = rboxs[0].tolist() + single_obj['bbox'] = rbox + single_obj['area'] = rbox[2] * rbox[3] + else: + xmin, ymin, xmax, ymax = min(single_obj_poly[0::2]), min(single_obj_poly[1::2]), \ + max(single_obj_poly[0::2]), max(single_obj_poly[1::2]) + + width, height = xmax - xmin, ymax - ymin + single_obj['bbox'] = xmin, ymin, width, height + single_obj['area'] = width * height + + single_obj['image_id'] = image_id + data_dict['annotations'].append(single_obj) + single_obj['id'] = inst_count + inst_count = inst_count + 1 + # add annotation + data_dict['annotations'].append(single_obj) + + with open(json_path, 'w') as f: + json.dump(data_dict, f) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='dota anno to coco') + parser.add_argument('--images_dir', help='path_to_images') + parser.add_argument('--label_dir', help='path_to_labelTxt', type=str) + parser.add_argument( + '--json_path', + help='save json path', + type=str, + default='dota_coco.json') + parser.add_argument( + '--is_obb', help='is_obb or not', type=bool, default=True) + parser.add_argument( + '--dota_version', + help='dota_version, v1.0 or v1.5 or v2.0', + type=str, + default='v1.0') + + args = parser.parse_args() + + # process + dota_2_coco(args.images_dir, args.label_dir, args.json_path, args.is_obb, + args.dota_version) + print('done!') diff --git a/ppdet/modeling/architectures/s2anet.py b/ppdet/modeling/architectures/s2anet.py index 72e9e820adcf230c5dd4a0d6c51c0496779e424a..0c8ae545299979a9f8008ed3a4bc64068cf2aae4 100644 --- a/ppdet/modeling/architectures/s2anet.py +++ b/ppdet/modeling/architectures/s2anet.py @@ -83,12 +83,13 @@ class S2ANet(BaseArch): nms_pre = self.s2anet_bbox_post_process.nms_pre pred_scores, pred_bboxes = self.s2anet_head.get_prediction(nms_pre) - # post_process - pred_cls_score_bbox, bbox_num, index = self.s2anet_bbox_post_process.get_prediction( - pred_scores, pred_bboxes, im_shape, scale_factor) - + pred_bboxes, bbox_num = self.s2anet_bbox_post_process(pred_scores, + pred_bboxes) + # rescale the prediction back to origin image + pred_bboxes = self.s2anet_bbox_post_process.get_pred( + pred_bboxes, bbox_num, im_shape, scale_factor) # output - output = {'bbox': pred_cls_score_bbox, 'bbox_num': bbox_num} + output = {'bbox': pred_bboxes, 'bbox_num': bbox_num} return output def get_loss(self, ): diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py index c77a5aec50406e257acef33235c6a2fce544c5cb..b14e26461d27c0757937945ba87159f2fe48dbfd 100644 --- a/ppdet/modeling/bbox_utils.py +++ b/ppdet/modeling/bbox_utils.py @@ -299,8 +299,6 @@ def delta2rbox(Rrois, :param wh_ratio_clip: :return: """ - means = paddle.to_tensor(means) - stds = paddle.to_tensor(stds) deltas = paddle.reshape(deltas, [-1, deltas.shape[-1]]) denorm_deltas = deltas * stds + means @@ -391,15 +389,12 @@ def bbox_decode(bbox_preds, return: bboxes: [N,H,W,5] """ - means = paddle.to_tensor(means) - stds = paddle.to_tensor(stds) num_imgs, H, W, _ = bbox_preds.shape bboxes_list = [] for img_id in range(num_imgs): bbox_pred = bbox_preds[img_id] # bbox_pred.shape=[5,H,W] bbox_delta = bbox_pred - anchors = paddle.to_tensor(anchors) bboxes = delta2rbox( anchors, bbox_delta, means, stds, wh_ratio_clip=1e-6) bboxes = paddle.reshape(bboxes, [H, W, 5]) @@ -512,8 +507,14 @@ def rbox2poly(rrects): poly:[x0,y0,x1,y1,x2,y2,x3,y3] """ polys = [] - for rrect in rrects: - x_ctr, y_ctr, width, height, angle = rrect[:5] + rrects = rrects.numpy() + for i in range(rrects.shape[0]): + rrect = rrects[i] + x_ctr = rrect[0] + y_ctr = rrect[1] + width = rrect[2] + height = rrect[3] + angle = rrect[4] tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2 rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]]) R = np.array([[np.cos(angle), -np.sin(angle)], @@ -526,3 +527,46 @@ def rbox2poly(rrects): polys.append(poly) polys = np.array(polys) return polys + + +def pd_rbox2poly(rrects): + """ + rrect:[x_ctr,y_ctr,w,h,angle] + to + poly:[x0,y0,x1,y1,x2,y2,x3,y3] + """ + N = rrects.shape[0] + + x_ctr = rrects[:, 0] + y_ctr = rrects[:, 1] + width = rrects[:, 2] + height = rrects[:, 3] + angle = rrects[:, 4] + + tl_x, tl_y, br_x, br_y = -width * 0.5, -height * 0.5, width * 0.5, height * 0.5 + + normal_rects = paddle.stack( + [tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y], axis=0) + normal_rects = paddle.reshape(normal_rects, [2, 4, N]) + normal_rects = paddle.transpose(normal_rects, [2, 0, 1]) + + sin, cos = paddle.sin(angle), paddle.cos(angle) + # M.shape=[N,2,2] + M = paddle.stack([cos, -sin, sin, cos], axis=0) + M = paddle.reshape(M, [2, 2, N]) + M = paddle.transpose(M, [2, 0, 1]) + + # polys:[N,8] + polys = paddle.matmul(M, normal_rects) + polys = paddle.transpose(polys, [2, 1, 0]) + polys = paddle.reshape(polys, [-1, N]) + polys = paddle.transpose(polys, [1, 0]) + polys[:, 0] += x_ctr + polys[:, 2] += x_ctr + polys[:, 4] += x_ctr + polys[:, 6] += x_ctr + polys[:, 1] += y_ctr + polys[:, 3] += y_ctr + polys[:, 5] += y_ctr + polys[:, 7] += y_ctr + return polys diff --git a/ppdet/modeling/heads/s2anet_head.py b/ppdet/modeling/heads/s2anet_head.py index 12e0c3144622d2cdb6c3f70f7cdf9d2c4c2483f9..ca8616bcc91bd14cb7cdb6b473354400f8bebc94 100644 --- a/ppdet/modeling/heads/s2anet_head.py +++ b/ppdet/modeling/heads/s2anet_head.py @@ -246,7 +246,8 @@ class S2ANetHead(nn.Layer): align_conv_size=3, use_sigmoid_cls=True, anchor_assign=RBoxAssigner().__dict__, - reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.0]): + reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.0], + cls_loss_weight=[1.0, 1.0]): super(S2ANetHead, self).__init__() self.stacked_convs = stacked_convs self.feat_in = feat_in @@ -267,6 +268,7 @@ class S2ANetHead(nn.Layer): self.sampling = False self.anchor_assign = anchor_assign self.reg_loss_weight = reg_loss_weight + self.cls_loss_weight = cls_loss_weight self.s2anet_head_out = None @@ -453,11 +455,19 @@ class S2ANetHead(nn.Layer): init_anchors = bbox_utils.rect2rbox(init_anchors) self.base_anchors[(i, featmap_size[0])] = init_anchors - #fam_reg1 = fam_reg - #fam_reg1.stop_gradient = True + fam_reg1 = fam_reg.clone() + fam_reg1.stop_gradient = True + pd_target_means = paddle.to_tensor( + np.array( + self.target_means, dtype=np.float32), dtype='float32') + pd_target_stds = paddle.to_tensor( + np.array( + self.target_stds, dtype=np.float32), dtype='float32') + pd_init_anchors = paddle.to_tensor( + np.array( + init_anchors, dtype=np.float32), dtype='float32') refine_anchor = bbox_utils.bbox_decode( - fam_reg.detach(), init_anchors, self.target_means, - self.target_stds) + fam_reg1, pd_init_anchors, pd_target_means, pd_target_stds) self.refine_anchor_list.append(refine_anchor) @@ -605,7 +615,9 @@ class S2ANetHead(nn.Layer): fam_bbox_losses.append(fam_bbox_total) fam_cls_loss = paddle.add_n(fam_cls_losses) - fam_cls_loss = fam_cls_loss * 2.0 + fam_cls_loss_weight = paddle.to_tensor( + self.cls_loss_weight[0], dtype='float32', stop_gradient=True) + fam_cls_loss = fam_cls_loss * fam_cls_loss_weight fam_reg_loss = paddle.add_n(fam_bbox_losses) return fam_cls_loss, fam_reg_loss @@ -686,7 +698,9 @@ class S2ANetHead(nn.Layer): odm_bbox_losses.append(odm_bbox_total) odm_cls_loss = paddle.add_n(odm_cls_losses) - odm_cls_loss = odm_cls_loss * 2.0 + odm_cls_loss_weight = paddle.to_tensor( + self.cls_loss_weight[1], dtype='float32', stop_gradient=True) + odm_cls_loss = odm_cls_loss * odm_cls_loss_weight odm_reg_loss = paddle.add_n(odm_bbox_losses) return odm_cls_loss, odm_reg_loss @@ -852,10 +866,14 @@ class S2ANetHead(nn.Layer): bbox_pred = paddle.gather(bbox_pred, topk_inds) scores = paddle.gather(scores, topk_inds) - target_means = (.0, .0, .0, .0, .0) - target_stds = (1.0, 1.0, 1.0, 1.0, 1.0) - bboxes = bbox_utils.delta2rbox(anchors, bbox_pred, target_means, - target_stds) + pd_target_means = paddle.to_tensor( + np.array( + self.target_means, dtype=np.float32), dtype='float32') + pd_target_stds = paddle.to_tensor( + np.array( + self.target_stds, dtype=np.float32), dtype='float32') + bboxes = bbox_utils.delta2rbox(anchors, bbox_pred, pd_target_means, + pd_target_stds) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index 429372a77450f1bfba421453eeec24545249d2a3..a41deac7f1a19e9bd170185be1652d754eae2068 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -17,7 +17,7 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register -from ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly +from ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly, pd_rbox2poly from . import ops try: from collections.abc import Sequence @@ -111,6 +111,8 @@ class BBoxPostProcess(object): pred_score = bboxes[:, 1:2] pred_bbox = bboxes[:, 2:] # rescale bbox to original image + print('pred_bbox', pred_bbox.shape, 'scale_factor_list', + scale_factor_list.shape) scaled_bbox = pred_bbox / scale_factor_list origin_h = self.origin_shape_list[:, 0] origin_w = self.origin_shape_list[:, 1] @@ -222,25 +224,25 @@ class FCOSPostProcess(object): @register class S2ANetBBoxPostProcess(object): + __shared__ = ['num_classes'] __inject__ = ['nms'] - def __init__(self, nms_pre=2000, min_bbox_size=0, nms=None): + def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0, nms=None): super(S2ANetBBoxPostProcess, self).__init__() + self.num_classes = num_classes self.nms_pre = nms_pre self.min_bbox_size = min_bbox_size self.nms = nms self.origin_shape_list = [] - def get_prediction(self, pred_scores, pred_bboxes, im_shape, scale_factor): + def __call__(self, pred_scores, pred_bboxes): """ pred_scores : [N, M] score pred_bboxes : [N, 5] xc, yc, w, h, a im_shape : [N, 2] im_shape scale_factor : [N, 2] scale_factor """ - # TODO: support bs>1 - pred_ploys = rbox2poly(pred_bboxes.numpy()) - pred_ploys = paddle.to_tensor(pred_ploys) + pred_ploys = pd_rbox2poly(pred_bboxes) pred_ploys = paddle.reshape( pred_ploys, [1, pred_ploys.shape[0], pred_ploys.shape[1]]) @@ -249,32 +251,26 @@ class S2ANetBBoxPostProcess(object): pred_scores = paddle.transpose(pred_scores, [1, 0]) pred_scores = paddle.reshape( pred_scores, [1, pred_scores.shape[0], pred_scores.shape[1]]) - pred_cls_score_bbox, bbox_num, index = self.nms(pred_ploys, pred_scores) - # post process scale - # result [n, 10] - if bbox_num > 0: - pred_bbox, bbox_num = self.post_process(pred_cls_score_bbox[:, 2:], - bbox_num, im_shape[0], - scale_factor[0]) + pred_cls_score_bbox, bbox_num, _ = self.nms(pred_ploys, pred_scores, + self.num_classes) - pred_cls_score_bbox = paddle.concat( - [pred_cls_score_bbox[:, 0:2], pred_bbox], axis=1) - else: + # Prevent empty bbox_pred from decode or NMS. + # Bboxes and score before NMS may be empty due to the score threshold. + if pred_cls_score_bbox.shape[0] == 0: pred_cls_score_bbox = paddle.to_tensor( np.array( - [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], - dtype='float32')) + [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32')) bbox_num = paddle.to_tensor(np.array([1], dtype='int32')) - return pred_cls_score_bbox, bbox_num, index + return pred_cls_score_bbox, bbox_num - def post_process(self, bboxes, bbox_num, im_shape, scale_factor): + def get_pred(self, bboxes, bbox_num, im_shape, scale_factor): """ Rescale, clip and filter the bbox from the output of NMS to get final prediction. Args: - bboxes(Tensor): bboxes [N, 8] + bboxes(Tensor): bboxes [N, 10] bbox_num(Tensor): bbox_num im_shape(Tensor): [1 2] scale_factor(Tensor): [1 2] @@ -283,14 +279,46 @@ class S2ANetBBoxPostProcess(object): including labels, scores and bboxes. The size of bboxes are corresponding to the original image. """ - + print('im_shape', im_shape, 'scale_factor', scale_factor) origin_shape = paddle.floor(im_shape / scale_factor + 0.5) - origin_h = origin_shape[0] - origin_w = origin_shape[1] + origin_shape_list = [] + scale_factor_list = [] + # scale_factor: scale_y, scale_x + for i in range(bbox_num.shape[0]): + expand_shape = paddle.expand(origin_shape[i:i + 1, :], + [bbox_num[i], 2]) + scale_y, scale_x = scale_factor[i][0], scale_factor[i][1] + scale = paddle.concat([ + scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x, + scale_y + ]) + expand_scale = paddle.expand(scale, [bbox_num[i], 8]) + origin_shape_list.append(expand_shape) + scale_factor_list.append(expand_scale) + + origin_shape_list = paddle.concat(origin_shape_list) + scale_factor_list = paddle.concat(scale_factor_list) + + # bboxes: [N, 10], label, score, bbox + print('bboxes', bboxes.shape) + pred_label_score = bboxes[:, 0:2] + print('pred_label_score', pred_label_score.shape) + pred_bbox = bboxes[:, 2:10:1] + print('pred_bbox', pred_bbox.shape) + + # rescale bbox to original image + scaled_bbox = pred_bbox / scale_factor_list + origin_h = origin_shape_list[:, 0] + origin_w = origin_shape_list[:, 1] + print('scaled_bbox', bboxes.shape) - bboxes[:, 0::2] = bboxes[:, 0::2] / scale_factor[0] - bboxes[:, 1::2] = bboxes[:, 1::2] / scale_factor[1] + bboxes = scaled_bbox + #print('bboxes', bboxes.shape, 'scale_factor', scale_factor.shape) + #print('bboxes[:, 0::2]', bboxes[:, 0::2].shape) + #print('scale_factor[0]', scale_factor) + #bboxes[:, 0::2] = bboxes[:, 0::2] / scale_factor[:, 0] + #bboxes[:, 1::2] = bboxes[:, 1::2] / scale_factor[:, 1] zeros = paddle.zeros_like(origin_h) x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros) @@ -301,6 +329,6 @@ class S2ANetBBoxPostProcess(object): y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros) x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros) y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros) - bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1) - bboxes = (bbox, bbox_num) - return bboxes + pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1) + pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1) + return pred_result diff --git a/static/ppdet/utils/export_utils.py b/static/ppdet/utils/export_utils.py index 3579ddb495f8e6055a512f9d8fe32897ca7abe22..a65fd6e2799edbe0cea361844f550afa8c8515a3 100644 --- a/static/ppdet/utils/export_utils.py +++ b/static/ppdet/utils/export_utils.py @@ -34,6 +34,7 @@ TRT_MIN_SUBGRAPH = { 'SSD': 3, 'RCNN': 40, 'RetinaNet': 40, + 'S2ANet': 40, 'EfficientDet': 40, 'Face': 3, 'TTFNet': 3, @@ -43,6 +44,7 @@ TRT_MIN_SUBGRAPH = { RESIZE_SCALE_SET = { 'RCNN', 'RetinaNet', + 'S2ANet', 'FCOS', 'SOLOv2', }