data.py 7.0 KB
Newer Older
H
HypoX64 已提交
1
import random
H
hypox64 已提交
2
import os
HypoX64's avatar
preview  
HypoX64 已提交
3 4 5
import numpy as np
import torch
import torchvision.transforms as transforms
H
HypoX64 已提交
6
import cv2
H
hypox64 已提交
7
from . import image_processing as impro
H
hypox64 已提交
8
from . import degradater
HypoX64's avatar
preview  
HypoX64 已提交
9

H
BVDNet  
hypox64 已提交
10 11 12 13 14 15
def to_tensor(data,gpu_id):
    """Wrap a NumPy array as a torch tensor, moving it to CUDA unless gpu_id is '-1'.

    The tensor is created with torch.from_numpy, so on the CPU path it shares
    memory with ``data``. ``gpu_id`` is compared with the string '-1'; any other
    value takes the ``.cuda()`` path.
    """
    tensor = torch.from_numpy(data)
    return tensor if gpu_id == '-1' else tensor.cuda()

16 17 18 19 20 21 22 23 24
def normalize(data):
    '''
    Map pixel values from [0, 255] to [-1, 1], computed in float32.
    '''
    scaled = data.astype(np.float32) / 255.0
    return (scaled - 0.5) / 0.5

def anti_normalize(data):
    '''
    Inverse of normalize: map [-1, 1] back to uint8 [0, 255].

    Out-of-range values are clipped; fractional values are truncated by the
    uint8 cast.
    '''
    pixels = (data * 0.5 + 0.5) * 255
    return np.clip(pixels, 0, 255).astype(np.uint8)

H
hypox64 已提交
25
def tensor2im(image_tensor, gray=False, rgb2bgr = True ,is0_1 = False, batch_index=0):
    """Convert one sample of a batched (N, C, H, W) tensor into a uint8 NumPy image.

    Args:
        image_tensor: torch tensor; ``image_tensor[batch_index]`` must be (C, H, W).
        gray: return a 2-D (H, W) array. Requires C == 1 (the channel axis is
            reshaped away).
        rgb2bgr: reverse the channel order of the 3-channel output.
        is0_1: tensor values are in [0, 1]; otherwise they are treated as [-1, 1].
        batch_index: which sample of the batch to convert.

    Returns:
        A C-contiguous uint8 array: (H, W) when ``gray`` is set, else (H, W, 3).
        Values are clipped to [0, 255] and truncated (not rounded) by the cast.
    """
    image_numpy = image_tensor.data[batch_index].cpu().float().numpy()

    if not is0_1:
        image_numpy = (image_numpy + 1) / 2.0
    image_numpy = np.clip(image_numpy * 255.0, 0, 255)

    # gray -> output 1ch
    if gray:
        h, w = image_numpy.shape[1:]
        return image_numpy.reshape(h, w).astype(np.uint8)

    # output 3ch: single-channel tensors are replicated to 3 channels
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = image_numpy.transpose((1, 2, 0))
    if rgb2bgr:
        image_numpy = image_numpy[..., ::-1]
    # The transpose/flip above leave a strided view; return a C-contiguous
    # array so OpenCV functions that write in place accept it directly.
    return np.ascontiguousarray(image_numpy.astype(np.uint8))
HypoX64's avatar
preview  
HypoX64 已提交
46 47


H
hypox64 已提交
48
def im2tensor(image_numpy, gray=False,bgr2rgb = True, reshape = True, gpu_id = '-1',is0_1 = False):
    """Convert a [0, 255] NumPy image into a float32 torch tensor.

    Args:
        image_numpy: (H, W) array when ``gray`` is set, else (H, W, C) array.
        gray: treat the input as a single-channel 2-D image.
        bgr2rgb: reverse the channel order (color inputs only).
        reshape: add leading dims -> (1, 1, H, W) for gray or (1, C, H, W) for
            color; otherwise the result is (H, W) / (C, H, W).
        gpu_id: '-1' keeps the tensor on the CPU; any other value calls ``.cuda()``.
        is0_1: scale to [0, 1] instead of [-1, 1]. Applied to both gray and
            color inputs (previously the gray path ignored it).

    Returns:
        float32 torch tensor.
    """
    def _scale(x):
        # [0, 255] -> [0, 1] when is0_1, else [-1, 1]
        x = x / 255.0
        return x if is0_1 else (x - 0.5) / 0.5

    if gray:
        h, w = image_numpy.shape
        image_tensor = torch.from_numpy(_scale(image_numpy)).float()
        if reshape:
            image_tensor = image_tensor.reshape(1, 1, h, w)
    else:
        h, w, ch = image_numpy.shape
        if bgr2rgb:
            # Reverse the channel axis; copy into a real array because
            # torch.from_numpy rejects negative-stride views.
            image_numpy = np.ascontiguousarray(image_numpy[..., ::-1])
        image_numpy = _scale(image_numpy).transpose((2, 0, 1))
        image_tensor = torch.from_numpy(image_numpy).float()
        if reshape:
            image_tensor = image_tensor.reshape(1, ch, h, w)
    if gpu_id != '-1':
        image_tensor = image_tensor.cuda()
    return image_tensor

H
hypox64 已提交
71 72 73 74 75
def shuffledata(data,target):
    """Shuffle ``data`` and ``target`` in place with the same permutation.

    Each array is shuffled starting from the same saved NumPy global RNG state,
    so equal-length inputs stay aligned element for element. Returns None.
    """
    saved_state = np.random.get_state()
    for array in (data, target):
        # Rewind to the saved state so both shuffles draw identical randoms.
        np.random.set_state(saved_state)
        np.random.shuffle(array)
H
HypoX64 已提交
76

H
BVDNet  
hypox64 已提交
77
def random_transform_single_mask(img,out_shape):
    """Randomly rescale, crop and (with p=0.5) flip a mask to ``out_shape``.

    Args:
        img: mask array indexed as (H, W[, C]).
        out_shape: (height, width) of the result.

    Returns:
        Array of size out_shape. Each axis is first upscaled by an independent
        factor in [1.1, 1.5] so a random crop of the requested size fits; the
        crop is flipped horizontally or vertically with p=0.25 each.
    """
    target_h, target_w = out_shape
    # Width factor is drawn before height factor (same order as before).
    scaled_w = int(target_w * random.uniform(1.1, 1.5))
    scaled_h = int(target_h * random.uniform(1.1, 1.5))
    img = cv2.resize(img, (scaled_w, scaled_h))

    src_h, src_w = img.shape[:2]
    top = int((src_h - target_h) * random.random())
    left = int((src_w - target_w) * random.random())
    img = img[top:top + target_h, left:left + target_w]

    if random.random() < 0.5:
        img = img[:, ::-1] if random.random() < 0.5 else img[::-1, :]

    if img.shape[0] != target_h or img.shape[1] != target_w:
        img = cv2.resize(img, (target_w, target_h))
    return img
H
HypoX64 已提交
92

H
BVDNet  
hypox64 已提交
93 94 95 96 97
def get_transform_params():
    """Draw one random augmentation configuration.

    Returns:
        {'flag': ..., 'rate': ...}, both keyed by 'crop', 'rotat', 'color',
        'flip' and 'degradate'.
        'flag' says whether each step is enabled: 'crop' and 'color' always;
        'rotat' and 'flip' with p=0.2; 'degradate' with p=0.5.
        'rate' holds each step's parameters: two crop offsets in [0, 1), a
        rotation draw in [0, 1), five color offsets in [-0.05, 0.05), a flip
        draw in [0, 1), and degradater parameters drawn with mod='weaker_2'.

    The result is what random_transform_single_image accepts as ``params``.
    """
    draw = np.random.random
    # Keep the draw order stable: with a seeded global RNG it decides which
    # value each key receives.
    flag_dict = {
        'crop': True,
        'rotat': draw() < 0.2,
        'color': True,
        'flip': draw() < 0.2,
        'degradate': draw() < 0.5,
    }
    rate_dict = {
        'crop': [draw(), draw()],
        'rotat': draw(),
        'color': [np.random.uniform(-0.05, 0.05) for _ in range(5)],
        'flip': draw(),
        'degradate': degradater.get_random_degenerate_params(mod='weaker_2'),
    }
    return {'flag': flag_dict, 'rate': rate_dict}

def random_transform_single_image(img,finesize,params=None,test_flag = False):
    """Apply one augmentation configuration to a single image.

    Args:
        img: image array indexed as (H, W[, C]).
        finesize: side length of the square crop/output.
        params: dict as returned by get_transform_params(); drawn fresh when
            None. Passing the same dict to several calls applies the same
            crop/rotation/color/flip to each image.
        test_flag: return right after degradation and cropping (no rotation,
            color, flip or final shape check).

    Returns:
        The augmented image. Outside test mode it is resized to
        (finesize, finesize), with a printed warning, if it is not already that
        size (e.g. the input was smaller than finesize).
    """
    if params is None:
        params = get_transform_params()
    flags, rates = params['flag'], params['rate']

    if flags['degradate']:
        img = degradater.degradate(img, rates['degradate'])

    if flags['crop']:
        h, w = img.shape[:2]
        # Clamp at 0: for an image smaller than finesize a negative offset
        # would wrap around and slice only a thin strip from the far edge.
        h_move = max(0, int((h - finesize) * rates['crop'][0]))
        w_move = max(0, int((w - finesize) * rates['crop'][1]))
        img = img[h_move:h_move + finesize, w_move:w_move + finesize]

    if test_flag:
        return img

    if flags['rotat']:
        h, w = img.shape[:2]
        # Rotate by a multiple of 90 degrees chosen from the rate in [0, 1).
        M = cv2.getRotationMatrix2D((w / 2, h / 2), 90 * int(4 * rates['rotat']), 1)
        img = cv2.warpAffine(img, M, (w, h))

    if flags['color']:
        c = rates['color']
        img = impro.color_adjust(img, c[0], c[1], c[2], c[3], c[4])

    if flags['flip']:
        # Horizontal flip; indexing only the first two axes works for both
        # 2-D and channel-last images.
        img = img[:, ::-1]

    #check shape
    if img.shape[0] != finesize or img.shape[1] != finesize:
        img = cv2.resize(img, (finesize, finesize))
        print('warning! shape error.')
    return img

def random_transform_pair_image(img,mask,finesize,test_flag = False):
    """Jointly augment an image and its mask.

    Geometric steps (rescale, crop, rotation, flip) are applied identically to
    both arrays; color adjustment and blur are applied to ``img`` only.

    Args:
        img: image array indexed as (H, W[, C]).
        mask: array expected to share img's height and width (the crop
            offsets are computed from img).
        finesize: side length of the square crop/output.
        test_flag: return right after the (optional) rescale and the crop.

    Returns:
        (img, mask). Outside test mode both are resized to (finesize, finesize),
        with a printed warning, if they are not already that size.
    """
    #random scale (p=0.5): the shorter side keeps its length, the longer side
    #is jittered by roughly +/-10%
    if random.random() < 0.5:
        h, w = img.shape[:2]
        loadsize = min(h, w)
        a = (float(h) / float(w)) * random.uniform(0.9, 1.1)
        if h < w:
            size = (int(loadsize / a), loadsize)
        else:
            size = (loadsize, int(loadsize * a))
        mask = cv2.resize(mask, size)
        img = cv2.resize(img, size)

    #random crop; clamp offsets at 0 so an image smaller than finesize is not
    #sliced with a negative (wrap-around) start that keeps only a thin strip
    h, w = img.shape[:2]
    h_move = max(0, int((h - finesize) * random.random()))
    w_move = max(0, int((w - finesize) * random.random()))
    img_crop = img[h_move:h_move + finesize, w_move:w_move + finesize]
    mask_crop = mask[h_move:h_move + finesize, w_move:w_move + finesize]

    if test_flag:
        return img_crop, mask_crop

    #random rotation by a multiple of 90 degrees (p=0.2)
    img, mask = img_crop, mask_crop
    if random.random() < 0.2:
        h, w = img.shape[:2]
        M = cv2.getRotationMatrix2D((w / 2, h / 2), 90 * int(4 * random.random()), 1)
        img = cv2.warpAffine(img, M, (w, h))
        mask = cv2.warpAffine(mask, M, (w, h))

    #random color (image only)
    img = impro.color_adjust(img, ran=True)

    #random flip (p=0.5, horizontal or vertical); indexing only the leading
    #axes works for 2-D as well as channel-last arrays
    if random.random() < 0.5:
        if random.random() < 0.5:
            img, mask = img[:, ::-1], mask[:, ::-1]
        else:
            img, mask = img[::-1], mask[::-1]

    #random blur (image only, p=0.5)
    if random.random() < 0.5:
        img = impro.dctblur(img, random.randint(1, 15))

    #check shape
    if img.shape[0] != finesize or img.shape[1] != finesize or mask.shape[0] != finesize or mask.shape[1] != finesize:
        img = cv2.resize(img, (finesize, finesize))
        mask = cv2.resize(mask, (finesize, finesize))
        print('warning! shape error.')
    return img, mask

H
hypox64 已提交
200
def showresult(img1,img2,img3,name,is0_1 = False):
    """Write the first sample of three batched tensors side by side to ``name``.

    Each tensor is converted with tensor2im(rgb2bgr=False, is0_1=is0_1), so the
    channels are written in the tensor's own order, and the three uint8 images
    are concatenated horizontally before cv2.imwrite.

    Args:
        img1, img2, img3: batched (N, C, H, W) tensors with equal H and W.
        name: output path passed to cv2.imwrite.
        is0_1: forwarded to tensor2im (values in [0, 1] instead of [-1, 1]).
    """
    panels = [tensor2im(t, rgb2bgr=False, is0_1=is0_1) for t in (img1, img2, img3)]
    # Concatenating the uint8 panels avoids a float64 canvas that relied on
    # imwrite's implicit 8-bit conversion, and also handles H != W (the old
    # canvas was sized from shape[3] only and broke on non-square input).
    showimg = np.concatenate(panels, axis=1)
    cv2.imwrite(name, showimg)