import random
import os
import numpy as np
import torch
import torchvision.transforms as transforms
import cv2
from . import image_processing as impro
from . import degradater
# Shared preprocessing pipeline: ToTensor followed by a per-channel
# Normalize with mean 0.5 and std 0.5 on each of the three channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
def to_tensor(data, gpu_id):
    """Wrap a NumPy array as a torch tensor and optionally move it to CUDA.

    Args:
        data: NumPy array handed to ``torch.from_numpy``.
        gpu_id: Device selector. ``'-1'`` (or the integer ``-1``) keeps the
            tensor on the CPU; any other value calls ``.cuda()``.

    Returns:
        The resulting tensor.
    """
    tensor = torch.from_numpy(data)
    # Compare the textual form so the integer -1 is treated like the
    # string '-1' instead of silently requesting CUDA.
    if str(gpu_id) != '-1':
        tensor = tensor.cuda()
    return tensor


def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr=True, is0_1=False, batch_index=0):
    """Convert one image of a batched tensor into a NumPy image array.

    Args:
        image_tensor: Tensor whose element ``image_tensor[batch_index]`` is a
            channel-first image (axis 0 is tiled/transposed as the channel
            axis below).
        imtype: dtype of the returned array.
        gray: If True, return a 2-D ``(h, w)`` array; the selected image is
            reshaped to ``(h, w)``, so it must hold exactly ``h * w`` values.
        rgb2bgr: If True, reverse the channel axis of the 3-channel output.
            Ignored when ``gray`` is True.
        is0_1: If True the values are taken as already lying in [0, 1];
            otherwise they are mapped with ``(x + 1) / 2`` first.
        batch_index: Index of the image to extract from the batch.

    Returns:
        ``(h, w)`` array when ``gray`` is True, otherwise ``(h, w, 3)``,
        scaled by 255, clipped to [0, 255] and cast to ``imtype``.
    """
    pixels = image_tensor.detach()[batch_index].cpu().float().numpy()

    if not is0_1:
        pixels = (pixels + 1) / 2.0
    pixels = np.clip(pixels * 255.0, 0, 255)

    # gray -> output 1ch
    if gray:
        h, w = pixels.shape[1:]
        return pixels.reshape(h, w).astype(imtype)

    # output 3ch: a single-channel image is repeated on the channel axis
    if pixels.shape[0] == 1:
        pixels = np.tile(pixels, (3, 1, 1))
    pixels = pixels.transpose((1, 2, 0))
    if rgb2bgr:
        # Reversing the channel axis yields a negative-stride view; copy it
        # into a contiguous buffer so downstream consumers can use it.
        pixels = np.ascontiguousarray(pixels[..., ::-1])
    return pixels.astype(imtype)


def im2tensor(image_numpy, imtype=np.uint8, gray=False, bgr2rgb=True, reshape=True, gpu_id=0, use_transform=True, is0_1=True):
    """Convert a NumPy image into a float torch tensor.

    Args:
        image_numpy: ``(h, w)`` array when ``gray`` is True, otherwise an
            ``(h, w, ch)`` array.
        imtype: Unused; kept so existing callers keep working.
        gray: Treat the input as a single-channel image. Gray images are
            always scaled with ``(x / 255 - 0.5) / 0.5``; ``is0_1`` is not
            consulted on this path.
        bgr2rgb: Reverse the channel axis of a colour image before conversion.
        reshape: If True, add a leading batch axis: ``(1, 1, h, w)`` for gray
            input, ``(1, ch, h, w)`` for colour input.
        gpu_id: Device selector. ``'-1'`` (or the integer ``-1``) keeps the
            tensor on the CPU; any other value calls ``.cuda()``.
        use_transform: If True, colour images go through the module-level
            ``transform`` pipeline; otherwise they are scaled manually.
        is0_1: Only used when ``use_transform`` is False: True scales to
            ``x / 255``, False to ``(x / 255 - 0.5) / 0.5``.

    Returns:
        The converted tensor.
    """
    if gray:
        h, w = image_numpy.shape
        normalized = (image_numpy / 255.0 - 0.5) / 0.5
        image_tensor = torch.from_numpy(normalized).float()
        if reshape:
            image_tensor = image_tensor.reshape(1, 1, h, w)
    else:
        h, w, ch = image_numpy.shape
        if bgr2rgb:
            # The reversed view has a negative stride; materialise a
            # contiguous copy before handing it to torch / torchvision.
            image_numpy = np.ascontiguousarray(image_numpy[..., ::-1])
        if use_transform:
            image_tensor = transform(image_numpy)
        else:
            if is0_1:
                scaled = image_numpy / 255.0
            else:
                scaled = (image_numpy / 255.0 - 0.5) / 0.5
            # HWC -> CHW
            image_tensor = torch.from_numpy(scaled.transpose((2, 0, 1))).float()
        if reshape:
            image_tensor = image_tensor.reshape(1, ch, h, w)
    # Compare the textual form so the integer -1 also selects the CPU.
    if str(gpu_id) != '-1':
        image_tensor = image_tensor.cuda()
    return image_tensor

def shuffledata(data, target):
    """Shuffle ``data`` and ``target`` in place using the same random draws.

    The global NumPy RNG state is captured once and restored before each
    shuffle, so both sequences are shuffled from an identical RNG state.
    Nothing is returned.
    """
    rng_snapshot = np.random.get_state()
    for sequence in (data, target):
        # Rewind the generator so every sequence sees the same draws.
        np.random.set_state(rng_snapshot)
        np.random.shuffle(sequence)
def random_transform_single_mask(img, out_shape):
    """Randomly rescale, crop and flip a mask to ``out_shape``.

    Args:
        img: Mask array accepted by ``cv2.resize``.
        out_shape: ``(out_h, out_w)`` of the returned mask.

    Returns:
        An array whose first two dimensions are ``(out_h, out_w)``.
    """
    out_h, out_w = out_shape
    # Enlarge each side by an independent factor in [1.1, 1.5];
    # cv2.resize takes the target size as (width, height).
    scaled_w = int(out_w * random.uniform(1.1, 1.5))
    scaled_h = int(out_h * random.uniform(1.1, 1.5))
    img = cv2.resize(img, (scaled_w, scaled_h))

    # Random crop window of the requested size.
    h, w = img.shape[:2]
    h_move = int((h - out_h) * random.random())
    w_move = int((w - out_w) * random.random())
    img = img[h_move:h_move + out_h, w_move:w_move + out_w]

    # With probability 0.5 flip, choosing horizontal or vertical evenly.
    if random.random() < 0.5:
        if random.random() < 0.5:
            flipped = img[:, ::-1]
        else:
            flipped = img[::-1, :]
        # A reversed slice is a negative-stride view, which
        # torch.from_numpy cannot wrap; return a contiguous copy instead.
        img = np.ascontiguousarray(flipped)

    # Safety net in case the crop did not produce the requested size.
    if img.shape[0] != out_h or img.shape[1] != out_w:
        img = cv2.resize(img, (out_w, out_h))
    return img
def get_transform_params():
    """Draw one random set of augmentation switches and magnitudes.

    Returns:
        ``{'flag': ..., 'rate': ...}`` where both inner dicts use the keys
        ``'crop'``, ``'rotat'``, ``'color'``, ``'flip'`` and ``'degradate'``.
        Flags say whether a step runs (crop and color are always on); rates
        carry the random values each step consumes.
    """
    draw = np.random.random

    # Draw order matters for reproducibility under a fixed seed:
    # rotat, flip, degradate flags first, then the rates.
    flags = {
        'crop': True,
        'rotat': draw() < 0.2,
        'color': True,
        'flip': draw() < 0.2,
        'degradate': draw() < 0.5,
    }

    rates = {
        'crop': [draw(), draw()],
        'rotat': draw(),
        # five independent offsets, each in [-0.05, 0.05)
        'color': [np.random.uniform(-0.05, 0.05) for _ in range(5)],
        'flip': draw(),
        'degradate': degradater.get_random_degenerate_params(mod='weaker_1'),
    }

    return {'flag': flags, 'rate': rates}

def random_transform_single_image(img, finesize, params=None, test_flag=False):
    """Apply the augmentation steps described by ``params`` to one image.

    Args:
        img: Image array; the first two axes are height and width.
        finesize: Side length of the square crop / final output.
        params: Dict shaped like the result of ``get_transform_params()``.
            Drawn fresh when None.
        test_flag: If True, return right after the crop step and skip every
            other augmentation (including the final shape check).

    Returns:
        The transformed image. When ``test_flag`` is False its first two
        dimensions are forced to ``(finesize, finesize)``.
    """
    if params is None:
        params = get_transform_params()
    flags, rates = params['flag'], params['rate']

    if flags['crop']:
        h, w = img.shape[:2]
        # Clamp at zero: for an image smaller than finesize the offset would
        # be negative and the slice would start from the far edge, throwing
        # away most of the image.
        h_move = max(int((h - finesize) * rates['crop'][0]), 0)
        w_move = max(int((w - finesize) * rates['crop'][1]), 0)
        img = img[h_move:h_move + finesize, w_move:w_move + finesize]

    if test_flag:
        return img

    if flags['rotat']:
        h, w = img.shape[:2]
        # The rate in [0, 1) selects a multiple of 90 degrees.
        angle = 90 * int(4 * rates['rotat'])
        M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
        img = cv2.warpAffine(img, M, (w, h))

    if flags['color']:
        color = rates['color']
        img = impro.color_adjust(img, color[0], color[1], color[2], color[3], color[4])

    if flags['flip']:
        # Horizontal flip; slicing only the width axis also works for 2-D
        # images, and the copy removes the negative stride of the view.
        img = np.ascontiguousarray(img[:, ::-1])

    if flags['degradate']:
        img = degradater.degradate(img, rates['degradate'])

    # check shape
    if img.shape[0] != finesize or img.shape[1] != finesize:
        img = cv2.resize(img, (finesize, finesize))
        print('warning! shape error.')
    return img

def random_transform_pair_image(img, mask, finesize, test_flag=False):
    """Apply the same random geometric augmentation to an image and its mask.

    Args:
        img: Image array; the first two axes are height and width.
        mask: Mask array sharing the image's height and width.
        finesize: Side length of the square crop / final output.
        test_flag: If True, return right after the random crop and skip the
            remaining augmentations (including the final shape check).

    Returns:
        ``(img, mask)`` after augmentation. When ``test_flag`` is False both
        are forced to ``(finesize, finesize)`` in their first two dimensions.
    """
    # random scale: jitter the aspect ratio by a factor in [0.9, 1.1]
    if random.random() < 0.5:
        h, w = img.shape[:2]
        loadsize = min((h, w))
        a = (float(h) / float(w)) * random.uniform(0.9, 1.1)
        # cv2.resize takes the target size as (width, height)
        if h < w:
            new_size = (int(loadsize / a), loadsize)
        else:
            new_size = (loadsize, int(loadsize * a))
        mask = cv2.resize(mask, new_size)
        img = cv2.resize(img, new_size)

    # random crop, same window for image and mask
    h, w = img.shape[:2]
    # Clamp at zero: for inputs smaller than finesize the offset would be
    # negative and the slice would start from the far edge.
    h_move = max(int((h - finesize) * random.random()), 0)
    w_move = max(int((w - finesize) * random.random()), 0)
    img_crop = img[h_move:h_move + finesize, w_move:w_move + finesize]
    mask_crop = mask[h_move:h_move + finesize, w_move:w_move + finesize]

    if test_flag:
        return img_crop, mask_crop

    # random rotation by a multiple of 90 degrees
    if random.random() < 0.2:
        h, w = img_crop.shape[:2]
        M = cv2.getRotationMatrix2D((w / 2, h / 2), 90 * int(4 * random.random()), 1)
        img = cv2.warpAffine(img_crop, M, (w, h))
        mask = cv2.warpAffine(mask_crop, M, (w, h))
    else:
        img, mask = img_crop, mask_crop

    # random color
    img = impro.color_adjust(img, ran=True)

    # random flip (horizontal or vertical); the copies remove the negative
    # strides that reversed slices carry
    if random.random() < 0.5:
        if random.random() < 0.5:
            img = np.ascontiguousarray(img[:, ::-1])
            mask = np.ascontiguousarray(mask[:, ::-1])
        else:
            img = np.ascontiguousarray(img[::-1])
            mask = np.ascontiguousarray(mask[::-1])

    # random blur
    if random.random() < 0.5:
        img = impro.dctblur(img, random.randint(1, 15))

    # check shape
    if (img.shape[0] != finesize or img.shape[1] != finesize
            or mask.shape[0] != finesize or mask.shape[1] != finesize):
        img = cv2.resize(img, (finesize, finesize))
        mask = cv2.resize(mask, (finesize, finesize))
        print('warning! shape error.')
    return img, mask

def showresult(img1, img2, img3, name, is0_1=False):
    """Write three image tensors side by side into one image file.

    Args:
        img1, img2, img3: Batched tensors accepted by ``tensor2im``; axes 2
            and 3 of ``img1`` are used as the height and width of every
            panel.
        name: Output path passed to ``cv2.imwrite``.
        is0_1: Forwarded to ``tensor2im`` (value range of the tensors).
    """
    # Use height and width separately so non-square images also fit;
    # the original used the width for both.
    height, width = img1.shape[2], img1.shape[3]
    # tensor2im returns uint8 by default, so a uint8 canvas holds the
    # panels exactly.
    canvas = np.zeros((height, width * 3, 3), dtype=np.uint8)
    for slot, batch in enumerate((img1, img2, img3)):
        left = slot * width
        canvas[:, left:left + width] = tensor2im(batch, rgb2bgr=False, is0_1=is0_1)
    cv2.imwrite(name, canvas)