提交 bb7ae5d9 编写于 作者: D dengkaipeng

fit for darknet

上级 ca44df94
......@@ -72,15 +72,15 @@ _C.pixel_stds = [0.229, 0.224, 0.225]
_C.learning_rate = 0.001
# maximum number of iterations
_C.max_iter = 500200
_C.max_iter = 500000
# warm up to learning rate
_C.warm_up_iter = 4000
_C.warm_up_factor = 0.
# lr steps_with_decay
_C.lr_steps = [400000, 450000]
_C.lr_gamma = 0.1
_C.lr_steps = [300000, 400000, 450000]
_C.lr_gamma = [0.2, 0.5, 0.1]
# L2 regularization hyperparameter
_C.weight_decay = 0.0005
......
......@@ -105,7 +105,7 @@ def random_flip(img, gtboxes, thresh=0.5):
gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
return img, gtboxes
def random_interp(img, size):
def random_interp(img, size, interp=None):
interp_method = [
cv2.INTER_NEAREST,
cv2.INTER_LINEAR,
......@@ -113,7 +113,8 @@ def random_interp(img, size):
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
]
interp = interp_method[random.randint(0, len(interp_method) - 1)]
if not interp or interp not in interp_method:
interp = interp_method[random.randint(0, len(interp_method) - 1)]
h, w, _ = img.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
......@@ -151,6 +152,13 @@ def random_expand(img, gtboxes, max_ratio=4., fill=None, keep_ratio=True, thresh
return out_img.astype('uint8'), gtboxes
def shuffle_gtbox(gtbox, gtlabel, gtscore):
    """Shuffle ground-truth boxes, labels and scores with one shared order.

    The three inputs are stacked column-wise into a single array so that a
    single row permutation keeps every box aligned with its own label and
    score.

    Args:
        gtbox: 2-D array with one row per box. The first four columns of the
            stacked array are returned as the boxes, so four coordinate
            columns are expected.
        gtlabel: 1-D array with one label per box.
        gtscore: 1-D array with one score per box.

    Returns:
        A tuple ``(boxes, labels, scores)`` holding the same rows as the
        inputs in a random order. All three share the dtype produced by
        ``np.concatenate``; callers cast them back as needed.
    """
    gt = np.concatenate(
        [gtbox, gtlabel[:, np.newaxis], gtscore[:, np.newaxis]], axis=1)
    # The permutation must cover the rows (axis 0, one per box). Building it
    # from the column count raised IndexError when fewer than six boxes were
    # passed and silently dropped every box after the sixth otherwise.
    idx = np.arange(gt.shape[0])
    np.random.shuffle(idx)
    gt = gt[idx, :]
    return gt[:, :4], gt[:, 4], gt[:, 5]
def image_mixup(img1, gtboxes1, gtlabels1, gtscores1, img2, gtboxes2, gtlabels2, gtscores2):
factor = np.random.beta(1.5, 1.5)
factor = max(0.0, min(1.0, factor))
......@@ -201,8 +209,9 @@ def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None):
img = random_distort(img)
img, gtboxes = random_expand(img, gtboxes, fill=means)
img, gtboxes, gtlabels, gtscores = random_crop(img, gtboxes, gtlabels, gtscores)
img = random_interp(img, size)
img = random_interp(img, size, cv2.INTER_LINEAR)
img, gtboxes = random_flip(img, gtboxes)
gtboxes, gtlabels, gtscores = shuffle_gtbox(gtboxes, gtlabels, gtscores)
return img.astype('float32'), gtboxes.astype('float32'), \
gtlabels.astype('int32'), gtscores.astype('float32')
......
......@@ -38,23 +38,23 @@ class DataSetReader(object):
self.has_parsed_categpry = False
def _parse_dataset_dir(self, mode):
# cfg.data_dir = "dataset/coco"
# cfg.train_file_list = 'annotations/instances_val2017.json'
# cfg.train_data_dir = 'val2017'
# cfg.dataset = "coco2017"
if 'coco2014' in cfg.dataset:
cfg.train_file_list = 'annotations/instances_train2014.json'
cfg.train_data_dir = 'train2014'
cfg.val_file_list = 'annotations/instances_val2014.json'
cfg.val_data_dir = 'val2014'
elif 'coco2017' in cfg.dataset:
cfg.train_file_list = 'annotations/instances_train2017.json'
cfg.train_data_dir = 'train2017'
cfg.val_file_list = 'annotations/instances_val2017.json'
cfg.val_data_dir = 'val2017'
else:
raise NotImplementedError('Dataset {} not supported'.format(
cfg.dataset))
cfg.data_dir = "dataset/coco"
cfg.train_file_list = 'annotations/instances_val2017.json'
cfg.train_data_dir = 'val2017'
cfg.dataset = "coco2017"
# if 'coco2014' in cfg.dataset:
# cfg.train_file_list = 'annotations/instances_train2014.json'
# cfg.train_data_dir = 'train2014'
# cfg.val_file_list = 'annotations/instances_val2014.json'
# cfg.val_data_dir = 'val2014'
# elif 'coco2017' in cfg.dataset:
# cfg.train_file_list = 'annotations/instances_train2017.json'
# cfg.train_data_dir = 'train2017'
# cfg.val_file_list = 'annotations/instances_val2017.json'
# cfg.val_data_dir = 'val2017'
# else:
# raise NotImplementedError('Dataset {} not supported'.format(
# cfg.dataset))
if mode == 'train':
cfg.train_file_list = os.path.join(cfg.data_dir, cfg.train_file_list)
......@@ -156,7 +156,7 @@ class DataSetReader(object):
h, w, _ = im.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_CUBIC)
out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_LINEAR)
mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1))
out_img = (out_img / 255.0 - mean) / std
......@@ -184,11 +184,6 @@ class DataSetReader(object):
im, gt_boxes, gt_labels, gt_scores = image_utils.image_augment(im, gt_boxes, gt_labels, gt_scores, size, mean)
# h, w, _ = im.shape
# im_scale_x = size / float(w)
# im_scale_y = size / float(h)
# im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_CUBIC)
mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1))
out_img = (im / 255.0 - mean) / std
......
......@@ -115,9 +115,9 @@ def parse_args():
# TRAIN TEST INFER
add_arg('input_size', int, 608, "Image input size of YOLOv3.")
add_arg('random_shape', bool, True, "Resize to random shape for train reader.")
add_arg('label_smooth', bool, True, "Use label smooth in class label.")
add_arg('no_mixup_iter', int, 40000, "Disable mixup in last N iter.")
add_arg('valid_thresh', float, 0.01, "Valid confidence score for NMS.")
add_arg('label_smooth', bool, False, "Use label smooth in class label.")
add_arg('no_mixup_iter', int, 500200, "Disable mixup in last N iter.")
add_arg('valid_thresh', float, 0.005, "Valid confidence score for NMS.")
add_arg('nms_thresh', float, 0.45, "NMS threshold.")
add_arg('nms_topk', int, 400, "The number of boxes to perform NMS.")
add_arg('nms_posk', int, 100, "The number of boxes of NMS output.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册