提交 0d004626 编写于 作者: F FlyingQianMM

support adding background images for detection training

上级 fd42b2f6
......@@ -128,7 +128,6 @@ class CocoDetection(VOCDetection):
coco_rec = (im_info, label_info)
self.file_list.append([im_fname, coco_rec])
if not len(self.file_list) > 0:
raise Exception('not found any coco record in %s' % (ann_file))
logging.info("{} samples in file {}".format(
......
......@@ -14,6 +14,7 @@
from __future__ import absolute_import
import copy
import os
import os.path as osp
import random
import numpy as np
......@@ -170,6 +171,43 @@ class VOCDetection(Dataset):
self.coco_gt.dataset = annotations
self.coco_gt.createIndex()
def append_backgrounds(self, image_dir):
    """Append object-free background images to the dataset's file list.

    Every picture found directly inside ``image_dir`` is registered as an
    extra sample carrying one placeholder ("false") ground-truth box of
    negligible size with class id 0 (presumably the background class --
    confirm against the transforms that consume these records).  New
    samples receive image ids continuing after the largest id currently
    known to ``self.coco_gt``.  ``self.num_samples`` is refreshed at the end.

    Args:
        image_dir: Path of the directory holding the background images.

    Raises:
        Exception: If ``image_dir`` is not an existing directory.
    """
    import warnings

    import cv2

    # isdir (rather than exists) so that a regular file is rejected here
    # with a clear message instead of failing later inside os.listdir.
    if not osp.isdir(image_dir):
        raise Exception("{} background images directory does not exist.".format(image_dir))

    # Guard against an annotation set without images: max() of an empty
    # sequence would raise ValueError.
    existing_ids = self.coco_gt.getImgIds()
    max_img_id = max(existing_ids) if existing_ids else 0

    # os.listdir returns entries in arbitrary order; sorting makes the
    # image-id assignment reproducible between runs.
    for image in sorted(os.listdir(image_dir)):
        if not is_pic(image):
            continue
        im_fname = osp.join(image_dir, image)
        img_data = cv2.imread(im_fname)
        # cv2.imread signals an unreadable/corrupt file by returning None
        # instead of raising; skip such files explicitly.
        if img_data is None:
            warnings.warn(
                "Skipping unreadable background image: {}".format(im_fname))
            continue
        im_h, im_w = img_data.shape[:2]
        max_img_id += 1

        # False ground truth.  Fresh objects are built for every image so
        # that records never share mutable arrays.
        gt_bbox = np.array([[0, 0, 1e-05, 1e-05]], dtype=np.float32)
        gt_class = np.array([[0]], dtype=np.int32)
        gt_score = np.ones((1, 1), dtype=np.float32)
        is_crowd = np.array([[0]], dtype=np.int32)
        difficult = np.zeros((1, 1), dtype=np.int32)
        gt_poly = [[[0, 0, 0, 1e-05, 1e-05, 1e-05, 1e-05, 0]]]

        im_info = {
            'im_id': np.array([max_img_id]).astype('int32'),
            'image_shape': np.array([im_h, im_w]).astype('int32'),
        }
        label_info = {
            'is_crowd': is_crowd,
            'gt_class': gt_class,
            'gt_bbox': gt_bbox,
            'gt_score': gt_score,
            'difficult': difficult,
            'gt_poly': gt_poly
        }
        coco_rec = (im_info, label_info)
        self.file_list.append([im_fname, coco_rec])
    self.num_samples = len(self.file_list)
def iterator(self):
self._epoch += 1
self._pos = 0
......
......@@ -204,7 +204,7 @@ class MaskRCNN(object):
bg_thresh_hi=self.bg_thresh_hi,
bg_thresh_lo=self.bg_thresh_lo,
bbox_reg_weights=self.bbox_reg_weights,
calass_nums=self.num_classes,
class_nums=self.num_classes,
use_random=self.rpn_head.use_random)
rois = outputs[0]
......
......@@ -722,22 +722,38 @@ class MixupImage(DetTransform):
'Becasuse gt_bbox/gt_class/gt_score is not in label_info!')
gt_bbox1 = label_info['gt_bbox']
gt_bbox2 = im_info['mixup'][2]['gt_bbox']
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
gt_class1 = label_info['gt_class']
gt_class2 = im_info['mixup'][2]['gt_class']
gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
gt_score1 = label_info['gt_score']
gt_score2 = im_info['mixup'][2]['gt_score']
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
if 'gt_poly' in label_info:
gt_poly1 = label_info['gt_poly']
gt_poly2 = im_info['mixup'][2]['gt_poly']
label_info['gt_poly'] = gt_poly1 + gt_poly2
is_crowd1 = label_info['is_crowd']
is_crowd2 = im_info['mixup'][2]['is_crowd']
is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
if 0 not in gt_class1 and 0 not in gt_class2:
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
if 'gt_poly' in label_info:
label_info['gt_poly'] = gt_poly1 + gt_poly2
is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
elif 0 in gt_class1:
gt_bbox = gt_bbox2
gt_class = gt_class2
gt_score = gt_score2 * (1. - factor)
if 'gt_poly' in label_info:
label_info['gt_poly'] = gt_poly2
is_crowd = is_crowd2
else:
gt_bbox = gt_bbox1
gt_class = gt_class1
gt_score = gt_score1 * factor
if 'gt_poly' in label_info:
label_info['gt_poly'] = gt_poly1
is_crowd = is_crowd1
label_info['gt_bbox'] = gt_bbox
label_info['gt_score'] = gt_score
label_info['gt_class'] = gt_class
......@@ -809,6 +825,8 @@ class RandomExpand(DetTransform):
if np.random.uniform(0., 1.) < self.prob:
return (im, im_info, label_info)
if 'gt_class' in label_info and 0 in label_info['gt_class']:
return (im, im_info, label_info)
image_shape = im_info['image_shape']
height = int(image_shape[0])
width = int(image_shape[1])
......@@ -904,6 +922,8 @@ class RandomCrop(DetTransform):
if len(label_info['gt_bbox']) == 0:
return (im, im_info, label_info)
if 'gt_class' in label_info and 0 in label_info['gt_class']:
return (im, im_info, label_info)
image_shape = im_info['image_shape']
w = image_shape[1]
......@@ -1199,9 +1219,10 @@ class ArrangeYOLOv3(DetTransform):
if gt_num > 0:
label_info['gt_class'][:gt_num, 0] = label_info[
'gt_class'][:gt_num, 0] - 1
gt_bbox[:gt_num, :] = label_info['gt_bbox'][:gt_num, :]
gt_class[:gt_num] = label_info['gt_class'][:gt_num, 0]
gt_score[:gt_num] = label_info['gt_score'][:gt_num, 0]
if -1 not in label_info['gt_class']:
gt_bbox[:gt_num, :] = label_info['gt_bbox'][:gt_num, :]
gt_class[:gt_num] = label_info['gt_class'][:gt_num, 0]
gt_score[:gt_num] = label_info['gt_score'][:gt_num, 0]
# parse [x1, y1, x2, y2] to [x, y, w, h]
gt_bbox[:, 2:4] = gt_bbox[:, 2:4] - gt_bbox[:, :2]
gt_bbox[:, :2] = gt_bbox[:, :2] + gt_bbox[:, 2:4] / 2.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册