dataset.py 3.9 KB
Newer Older
C
chenyun 已提交
1 2
from __future__ import  absolute_import
from __future__ import  division
C
init  
chenyuntc 已提交
3
import torch as t
C
chenyun 已提交
4
from data.voc_dataset import VOCBboxDataset
C
init  
chenyuntc 已提交
5 6
from skimage import transform as sktsf
from torchvision import transforms as tvtsf
C
chenyun 已提交
7
from data import util
C
fix  
cy 已提交
8
import numpy as np
C
chenyuntc 已提交
9
from utils.config import opt
C
chenyuntc 已提交
10 11


C
chenyuntc 已提交
12 13
def inverse_normalize(img):
    """Approximately undo the normalization so the image can be visualized.

    Two branches, selected by ``opt.caffe_pretrain``:

    * Caffe branch: the per-channel vector ``[122.7717, 115.9465, 102.9801]``
      (the same vector ``caffe_normalize`` subtracts) is added back and the
      first axis is reversed, which turns the BGR order produced by
      ``caffe_normalize`` back into RGB. No clipping is applied here.
    * Otherwise: a single approximate std (0.225) and mean (0.45) are used
      for all channels instead of the exact per-channel constants, the
      result is clipped to ``[0, 1]`` and scaled by 255.

    Args:
        img (~numpy.ndarray): A normalized image. The Caffe branch needs the
            first axis to have length 3 so the mean broadcasts
            (presumably CHW, as produced by ``preprocess``).

    Returns:
        ~numpy.ndarray: The approximately un-normalized image. In the Caffe
        branch this is a reversed view (negative stride along axis 0).
    """
    if opt.caffe_pretrain:
        img = img + (np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1))
        # Reverse the channel axis: BGR -> RGB.
        return img[::-1, :, :]
    # approximate un-normalize for visualize
    return (img * 0.225 + 0.45).clip(min=0, max=1) * 255

C
chenyuntc 已提交
19 20 21 22 23 24 25

def pytorch_normalze(img):
    """Normalize an RGB image with fixed per-channel mean and std.

    Reference: https://github.com/pytorch/vision/issues/223

    Each channel is transformed as ``(x - mean) / std`` by
    ``torchvision.transforms.Normalize`` with
    ``mean=[0.485, 0.456, 0.406]`` and ``std=[0.229, 0.224, 0.225]``.

    Args:
        img (~numpy.ndarray): The image to normalize. ``preprocess`` passes
            a 3-channel array already divided by 255 (presumably CHW, RGB,
            values in ``[0, 1]``).

    Returns:
        ~numpy.ndarray: The normalized image. For inputs in ``[0, 1]`` the
        values are roughly zero-centred and lie within about
        ``[-2.12, 2.64]`` (not ``[-1, 1]``).
    """
    normalize = tvtsf.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
    # Round-trip through a torch tensor because Normalize operates on tensors.
    img = normalize(t.from_numpy(img))
    return img.numpy()

C
chenyuntc 已提交
30

C
chenyuntc 已提交
31 32 33 34
def caffe_normalize(img):
    """Convert an RGB image to mean-subtracted BGR in the 0-255 scale.

    Steps: reorder the first axis from RGB to BGR, multiply by 255, subtract
    the per-channel vector ``[122.7717, 115.9465, 102.9801]`` and cast the
    result to ``float32``. The input array is not modified.

    Args:
        img (~numpy.ndarray): A 3-channel image whose first axis is the
            channel axis in RGB order. ``preprocess`` passes values already
            divided by 255 (i.e. in ``[0, 1]``).

    Returns:
        ~numpy.ndarray: A new ``float32`` array in BGR order. For inputs in
        ``[0, 1]`` the values lie within about ``[-122.8, 152.0]``.
    """
    img = img[[2, 1, 0], :, :]  # RGB -> BGR (fancy indexing makes a copy)
    img = img * 255
    # NOTE(review): the vector is applied after the BGR reordering, so
    # 122.7717 is subtracted from the blue channel and 102.9801 from the red
    # one. inverse_normalize adds the same vector back, so the pair is
    # self-consistent -- confirm this order matches the pretrained weights.
    mean = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)
    img = (img - mean).astype(np.float32, copy=True)
    return img

C
chenyuntc 已提交
41

C
chenyuntc 已提交
42
def preprocess(img, min_size=600, max_size=1000):
    """Resize and normalize an image for feature extraction.

    The image is scaled by a single factor chosen so that the shorter edge
    becomes :obj:`min_size`, unless that would make the longer edge exceed
    :obj:`max_size`; in that case the longer edge is scaled to
    :obj:`max_size` instead. After scaling, the shorter edge is therefore at
    most :obj:`min_size` and the longer edge at most :obj:`max_size`.

    Pixel values are divided by 255 before resizing, and the resized image
    is normalized with ``caffe_normalize`` when ``opt.caffe_pretrain`` is
    set, otherwise with ``pytorch_normalze``.

    Args:
        img (~numpy.ndarray): An image. This is in CHW and RGB format.
            The range of its value is :math:`[0, 255]`.
        min_size (int): Target length of the shorter edge.
        max_size (int): Upper bound for the length of the longer edge.

    Returns:
        ~numpy.ndarray: A preprocessed image.

    """
    C, H, W = img.shape
    # The smaller of the two factors satisfies both edge constraints at once.
    # ``/`` is true division (``from __future__ import division``).
    scale1 = min_size / min(H, W)
    scale2 = max_size / max(H, W)
    scale = min(scale1, scale2)
    img = img / 255.
    # NOTE(review): the target height/width are passed as floats; how they
    # are turned into integer pixel counts is left to skimage -- confirm the
    # rounding on the installed version.
    img = sktsf.resize(img, (C, H * scale, W * scale), mode='reflect',
                       anti_aliasing=False)
    if opt.caffe_pretrain:
        normalize = caffe_normalize
    else:
        normalize = pytorch_normalze
    return normalize(img)
C
init  
chenyuntc 已提交
75

C
chenyuntc 已提交
76

C
init  
chenyuntc 已提交
77 78
class Transform(object):
    """Training-time transform: resize/normalize the image, rescale the
    boxes to match, then apply a random horizontal flip to both.
    """

    def __init__(self, min_size=600, max_size=1000):
        """Store the size limits forwarded to ``preprocess``.

        Args:
            min_size (int): Target length of the shorter image edge.
            max_size (int): Upper bound for the longer image edge.
        """
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, in_data):
        """Transform one ``(img, bbox, label)`` example.

        Args:
            in_data (tuple): ``(img, bbox, label)``. ``img`` must be a
                3-dimensional array whose last two axes are height and
                width; ``bbox`` and ``label`` are passed to the ``util``
                helpers / returned as-is.

        Returns:
            tuple: ``(img, bbox, label, scale)`` where ``img`` and ``bbox``
            are resized and possibly flipped, ``label`` is unchanged and
            ``scale`` is the output height divided by the input height.
        """
        img, bbox, label = in_data
        _, H, W = img.shape
        img = preprocess(img, self.min_size, self.max_size)
        _, o_H, o_W = img.shape
        # Height ratio is used as the scale; preprocess applies one factor
        # to both axes, so the width ratio should match up to rounding.
        scale = o_H / H
        bbox = util.resize_bbox(bbox, (H, W), (o_H, o_W))

        # horizontally flip
        img, params = util.random_flip(
            img, x_random=True, return_param=True)
        # Apply the same flip decision to the boxes so they stay aligned.
        bbox = util.flip_bbox(
            bbox, (o_H, o_W), x_flip=params['x_flip'])

        return img, bbox, label, scale

C
chenyuntc 已提交
99

C
chenyuntc 已提交
100
class Dataset:
    """Training dataset: VOC examples passed through ``Transform``."""

    def __init__(self, opt):
        """Build the underlying VOC dataset and the training transform.

        Args:
            opt: Configuration object; ``voc_data_dir``, ``min_size`` and
                ``max_size`` are read from it.
        """
        self.opt = opt
        self.db = VOCBboxDataset(opt.voc_data_dir)
        self.tsf = Transform(opt.min_size, opt.max_size)

    def __getitem__(self, idx):
        """Return the transformed example at ``idx``.

        Returns:
            tuple: ``(img, bbox, label, scale)`` as produced by the
            transform; the ``difficult`` flags from the underlying dataset
            are dropped.
        """
        ori_img, bbox, label, difficult = self.db.get_example(idx)

        img, bbox, label, scale = self.tsf((ori_img, bbox, label))
        # TODO: check whose stride is negative to fix this instead copy all
        # Some of the arrays can come back as views with negative strides
        # (e.g. after a flip); copying each one yields fresh arrays with
        # ordinary positive strides.
        return img.copy(), bbox.copy(), label.copy(), scale

    def __len__(self):
        """Return the number of examples in the underlying dataset."""
        return len(self.db)

C
chenyuntc 已提交
117

C
chenyuntc 已提交
118 119
class TestDataset:
    """Evaluation dataset: VOC examples that are resized and normalized but
    not flipped, returned together with the original image size.
    """

    def __init__(self, opt, split='test', use_difficult=True):
        """Build the underlying VOC dataset for the given split.

        Args:
            opt: Configuration object; ``voc_data_dir`` is read here and
                ``min_size`` / ``max_size`` are read in ``__getitem__``.
            split (str): Forwarded to ``VOCBboxDataset``.
            use_difficult (bool): Forwarded to ``VOCBboxDataset``.
        """
        self.opt = opt
        self.db = VOCBboxDataset(opt.voc_data_dir, split=split, use_difficult=use_difficult)

    def __getitem__(self, idx):
        """Return the preprocessed example at ``idx``.

        Returns:
            tuple: ``(img, size, bbox, label, difficult)`` where ``img`` is
            the preprocessed image, ``size`` is ``ori_img.shape[1:]`` (the
            original height and width) and the remaining items come
            unchanged from the underlying dataset.
        """
        ori_img, bbox, label, difficult = self.db.get_example(idx)
        # Use the configured sizes, as the training Dataset does, so that
        # evaluation runs at the same resolution as training instead of
        # always falling back to preprocess's 600/1000 defaults.
        img = preprocess(ori_img, self.opt.min_size, self.opt.max_size)
        return img, ori_img.shape[1:], bbox, label, difficult

    def __len__(self):
        """Return the number of examples in the underlying dataset."""
        return len(self.db)