提交 53440c59 编写于 作者: D dengkaipeng

change mean/std to /255.0

上级 19bbe367
...@@ -121,7 +121,7 @@ def random_interp(img, size, interp=None): ...@@ -121,7 +121,7 @@ def random_interp(img, size, interp=None):
img = cv2.resize(img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp) img = cv2.resize(img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
return img return img
def random_expand(img, gtboxes, max_ratio=4., fill=None, keep_ratio=True, thresh=0.5): def random_expand(img, gtboxes, max_ratio=2., fill=None, keep_ratio=True, thresh=0.5):
if random.random() > thresh: if random.random() > thresh:
return img, gtboxes return img, gtboxes
......
...@@ -38,23 +38,23 @@ class DataSetReader(object): ...@@ -38,23 +38,23 @@ class DataSetReader(object):
self.has_parsed_categpry = False self.has_parsed_categpry = False
def _parse_dataset_dir(self, mode): def _parse_dataset_dir(self, mode):
cfg.data_dir = "dataset/coco" # cfg.data_dir = "dataset/coco"
cfg.train_file_list = 'annotations/instances_val2017.json' # cfg.train_file_list = 'annotations/instances_val2017.json'
cfg.train_data_dir = 'val2017' # cfg.train_data_dir = 'val2017'
cfg.dataset = "coco2017" # cfg.dataset = "coco2017"
# if 'coco2014' in cfg.dataset: if 'coco2014' in cfg.dataset:
# cfg.train_file_list = 'annotations/instances_train2014.json' cfg.train_file_list = 'annotations/instances_train2014.json'
# cfg.train_data_dir = 'train2014' cfg.train_data_dir = 'train2014'
# cfg.val_file_list = 'annotations/instances_val2014.json' cfg.val_file_list = 'annotations/instances_val2014.json'
# cfg.val_data_dir = 'val2014' cfg.val_data_dir = 'val2014'
# elif 'coco2017' in cfg.dataset: elif 'coco2017' in cfg.dataset:
# cfg.train_file_list = 'annotations/instances_train2017.json' cfg.train_file_list = 'annotations/instances_train2017.json'
# cfg.train_data_dir = 'train2017' cfg.train_data_dir = 'train2017'
# cfg.val_file_list = 'annotations/instances_val2017.json' cfg.val_file_list = 'annotations/instances_val2017.json'
# cfg.val_data_dir = 'val2017' cfg.val_data_dir = 'val2017'
# else: else:
# raise NotImplementedError('Dataset {} not supported'.format( raise NotImplementedError('Dataset {} not supported'.format(
# cfg.dataset)) cfg.dataset))
if mode == 'train': if mode == 'train':
cfg.train_file_list = os.path.join(cfg.data_dir, cfg.train_file_list) cfg.train_file_list = os.path.join(cfg.data_dir, cfg.train_file_list)
...@@ -157,10 +157,11 @@ class DataSetReader(object): ...@@ -157,10 +157,11 @@ class DataSetReader(object):
im_scale_x = size / float(w) im_scale_x = size / float(w)
im_scale_y = size / float(h) im_scale_y = size / float(h)
out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_LINEAR) out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_LINEAR)
mean = np.array(mean).reshape((1, 1, -1)) # mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1)) # std = np.array(std).reshape((1, 1, -1))
out_img = (out_img / 255.0 - mean) / std # out_img = (out_img / 255.0 - mean) / std
out_img = out_img.transpose((2, 0, 1)) # out_img = out_img.transpose((2, 0, 1))
out_img = im.astype('float32').transpose((2, 0, 1)) / 255.0
return (out_img, int(img['id']), (h, w)) return (out_img, int(img['id']), (h, w))
...@@ -184,10 +185,11 @@ class DataSetReader(object): ...@@ -184,10 +185,11 @@ class DataSetReader(object):
im, gt_boxes, gt_labels, gt_scores = image_utils.image_augment(im, gt_boxes, gt_labels, gt_scores, size, mean) im, gt_boxes, gt_labels, gt_scores = image_utils.image_augment(im, gt_boxes, gt_labels, gt_scores, size, mean)
mean = np.array(mean).reshape((1, 1, -1)) # mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1)) # std = np.array(std).reshape((1, 1, -1))
out_img = (im / 255.0 - mean) / std # out_img = (im / 255.0 - mean) / std
out_img = out_img.transpose((2, 0, 1)).astype('float32') # out_img = out_img.transpose((2, 0, 1)).astype('float32')
out_img = im.astype('float32').transpose((2, 0, 1)) / 255.0
return (out_img, gt_boxes, gt_labels, gt_scores) return (out_img, gt_boxes, gt_labels, gt_scores)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册