dataset.py 3.9 KB
Newer Older
C
init  
chenyuntc 已提交
1 2 3 4 5
import torch as t
from .voc_dataset import VOCBboxDataset
from skimage import transform as sktsf
from torchvision import transforms as tvtsf
from . import util
C
fix  
cy 已提交
6
import numpy as np
C
chenyuntc 已提交
7
from utils.config import opt
C
chenyuntc 已提交
8 9


C
chenyuntc 已提交
10 11
def inverse_normalize(img):
    """Approximately undo the input normalization so ``img`` can be viewed.

    The inverse that is applied depends on ``opt.caffe_pretrain``:

    * truthy: the per-channel means ``[122.7717, 115.9465, 102.9801]`` are
      added back and the first (channel) axis is reversed.
    * falsy: every channel is multiplied by 0.225 and shifted by 0.45, the
      result is clipped to ``[0, 1]`` and then scaled by 255. This is only
      an approximation intended for visualization.

    Args:
        img: Normalized image array. The caffe branch broadcasts a
            ``(3, 1, 1)`` mean against it, so a 3-channel, channel-first
            layout is expected there.

    Returns:
        The de-normalized image. In the caffe branch the channel axis is
        reversed by slicing with a step of -1.
    """
    if not opt.caffe_pretrain:
        # approximate un-normalize for visualization
        restored = img * 0.225 + 0.45
        return restored.clip(min=0, max=1) * 255

    channel_mean = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)
    shifted = img + channel_mean
    return shifted[::-1, :, :]

C
chenyuntc 已提交
17 18 19 20 21 22 23

def pytorch_normalze(img):
    """Normalize an image with per-channel mean/std via torchvision.

    Background: https://github.com/pytorch/vision/issues/223

    The array is wrapped as a tensor with ``t.from_numpy``, passed through
    ``tvtsf.Normalize`` using mean ``[0.485, 0.456, 0.406]`` and std
    ``[0.229, 0.224, 0.225]``, and converted back to a numpy array.

    Args:
        img (~numpy.ndarray): Image to normalize.
            NOTE(review): presumably CHW, RGB, values in [0, 1] (this is
            what ``preprocess`` passes in) -- confirm for other callers.

    Returns:
        ~numpy.ndarray: The normalized image, still RGB; values are
        roughly zero-centred (the original author's note: "appr -1~1").
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    normalizer = tvtsf.Normalize(mean=channel_mean, std=channel_std)
    tensor = t.from_numpy(img)
    return normalizer(tensor).numpy()

C
chenyuntc 已提交
28

C
chenyuntc 已提交
29 30 31 32
def caffe_normalize(img):
    """Convert an image to the caffe-style input representation.

    Channels 2, 1, 0 are picked along the first axis (RGB -> BGR), values
    are multiplied by 255, and the per-channel means
    ``[122.7717, 115.9465, 102.9801]`` are subtracted.

    Args:
        img (~numpy.ndarray): Channel-first image with at least three
            channels. NOTE(review): values presumably in [0, 1], as
            produced by ``preprocess`` -- confirm for other callers.

    Returns:
        ~numpy.ndarray: A new ``float32`` array in BGR order; for inputs in
        [0, 1] the values fall approximately in [-125, 125].
    """
    channel_mean = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)
    # Selecting channels 2, 1, 0 swaps RGB into BGR and yields a new array,
    # so the caller's input is never modified.
    bgr = np.take(img, [2, 1, 0], axis=0)
    centered = bgr * 255 - channel_mean
    return centered.astype(np.float32)

C
chenyuntc 已提交
39

C
chenyuntc 已提交
40
def preprocess(img, min_size=600, max_size=1000):
    """Preprocess an image for feature extraction.

    The image is rescaled by a single factor: the smaller of
    ``min_size / shorter_edge`` and ``max_size / longer_edge``. As a
    result the shorter edge ends up at most ``min_size`` and the longer
    edge at most ``max_size`` (up to the rounding done by the resize).

    Pixel values are divided by 255 before resizing. The resized image is
    then normalized with ``caffe_normalize`` when ``opt.caffe_pretrain`` is
    truthy and with ``pytorch_normalze`` otherwise.

    Args:
        img (~numpy.ndarray): An image. This is in CHW and RGB format.
            The range of its value is :math:`[0, 255]`.
        min_size (int): Bound for the shorter edge after scaling.
        max_size (int): Bound for the longer edge after scaling.

    Returns:
        ~numpy.ndarray:
        A preprocessed image.

    """
    channels, height, width = img.shape
    # Use the smaller factor so that neither edge exceeds its bound.
    scale = min(min_size / min(height, width),
                max_size / max(height, width))
    resized = sktsf.resize(img / 255.,
                           (channels, height * scale, width * scale),
                           mode='reflect')
    normalize = caffe_normalize if opt.caffe_pretrain else pytorch_normalze
    return normalize(resized)
C
init  
chenyuntc 已提交
76

C
chenyuntc 已提交
77

C
init  
chenyuntc 已提交
78 79
class Transform(object):
    """Training-time transform: resize/normalize, rescale boxes, random flip.

    Instances are callables taking an ``(img, bbox, label)`` tuple and
    returning ``(img, bbox, label, scale)``.
    """

    def __init__(self, min_size=600, max_size=1000):
        """Store the edge-length bounds later forwarded to ``preprocess``."""
        self.min_size, self.max_size = min_size, max_size

    def __call__(self, in_data):
        """Apply the transform to one example.

        Args:
            in_data: Tuple ``(img, bbox, label)``; ``img`` must have a
                three-element shape whose last two entries are height and
                width.

        Returns:
            tuple: ``(img, bbox, label, scale)`` where ``scale`` is the
            ratio of the preprocessed height to the input height and
            ``label`` is passed through unchanged.
        """
        img, bbox, label = in_data
        _, src_h, src_w = img.shape

        img = preprocess(img, self.min_size, self.max_size)
        _, dst_h, dst_w = img.shape
        scale = dst_h / src_h

        # Bring the boxes into the coordinate frame of the resized image.
        bbox = util.resize_bbox(bbox, (src_h, src_w), (dst_h, dst_w))

        # horizontally flip; the boxes follow whatever flip was drawn
        img, flip_params = util.random_flip(
            img, x_random=True, return_param=True)
        bbox = util.flip_bbox(
            bbox, (dst_h, dst_w), x_flip=flip_params['x_flip'])

        return img, bbox, label, scale

C
chenyuntc 已提交
100

C
chenyuntc 已提交
101
class Dataset:
    """Training dataset: VOC examples passed through ``Transform``."""

    def __init__(self, opt):
        """Build the underlying VOC dataset and the transform from ``opt``.

        Reads ``opt.voc_data_dir``, ``opt.min_size`` and ``opt.max_size``.
        """
        self.opt = opt
        self.db = VOCBboxDataset(opt.voc_data_dir)
        self.tsf = Transform(opt.min_size, opt.max_size)

    def __getitem__(self, idx):
        """Return ``(img, bbox, label, scale)`` for example ``idx``.

        The "difficult" flags returned by the underlying dataset are
        dropped.
        """
        ori_img, bbox, label, _ = self.db.get_example(idx)
        img, bbox, label, scale = self.tsf((ori_img, bbox, label))

        # TODO: check whose stride is negative to fix this instead of
        # copying everything. Some of the arrays can have negative strides
        # (presumably the flipped ones -- confirm), so fresh copies are
        # handed out.
        return img.copy(), bbox.copy(), label.copy(), scale

    def __len__(self):
        """Number of examples in the underlying dataset."""
        return len(self.db)

C
chenyuntc 已提交
118

C
chenyuntc 已提交
119 120
class TestDataset:
    """Evaluation dataset: VOC examples that are preprocessed but not flipped."""

    def __init__(self, opt, split='test', use_difficult=True):
        """Build the underlying VOC dataset.

        Args:
            opt: Configuration object; ``opt.voc_data_dir`` is read here and
                ``opt.min_size`` / ``opt.max_size`` are read per example.
            split: Dataset split forwarded to ``VOCBboxDataset``.
            use_difficult: Forwarded to ``VOCBboxDataset``.
        """
        self.opt = opt
        self.db = VOCBboxDataset(opt.voc_data_dir, split=split, use_difficult=use_difficult)

    def __getitem__(self, idx):
        """Return ``(img, original_hw, bbox, label, difficult)`` for ``idx``.

        ``original_hw`` is ``ori_img.shape[1:]``, i.e. the size of the image
        before preprocessing; ``bbox``, ``label`` and ``difficult`` are
        returned exactly as the underlying dataset provides them.
        """
        ori_img, bbox, label, difficult = self.db.get_example(idx)
        # Resize with the configured sizes, as the training Dataset does,
        # instead of silently using preprocess()'s hard-coded defaults.
        # The fallbacks keep the previous behavior for configuration
        # objects that do not define these attributes.
        min_size = getattr(self.opt, 'min_size', 600)
        max_size = getattr(self.opt, 'max_size', 1000)
        img = preprocess(ori_img, min_size, max_size)
        return img, ori_img.shape[1:], bbox, label, difficult

    def __len__(self):
        """Number of examples in the underlying dataset."""
        return len(self.db)