# preprocess.py
# Last committed by lijianshe02.
import cv2
import numpy as np


def generate_P_from_lmks(lmks, resize, w, h):
    """Build the unit-normalised pixel-to-landmark offset matrix.

    For every pixel of a ``resize`` x ``resize`` coordinate grid the offset
    to each landmark is computed. The offset volume is then resampled to
    64 x 64 with nearest-neighbour interpolation, flattened to 136 rows and
    every column is divided by its L2 norm.

    Args:
        lmks: landmark array. It is transposed and flattened to 136 values,
            so a (68, 2) layout is assumed; the first column is subtracted
            from the row grid and the second from the column grid.
        resize: side length of the square coordinate grid.
        w: unused; kept for interface compatibility.
        h: unused; kept for interface compatibility.

    Returns:
        Array of shape (136, 64 * 64) whose columns have unit L2 norm.
    """
    n_points = 68
    out_size = (64, 64)

    axis = np.linspace(0, resize - 1, resize)
    col_grid, row_grid = np.meshgrid(axis, axis)
    row_stack = np.repeat(row_grid[None], n_points, axis=0)
    col_stack = np.repeat(col_grid[None], n_points, axis=0)
    grid = np.concatenate([row_stack, col_stack], axis=0)

    # One landmark coordinate per grid plane, broadcast over all pixels.
    offsets = grid - lmks.transpose(1, 0).reshape(-1, 1, 1)

    # cv2.resize works on channel-last images, so move the 136 planes last
    # for the resize and back to the front afterwards.
    offsets = cv2.resize(offsets.transpose(1, 2, 0),
                         out_size,
                         interpolation=cv2.INTER_NEAREST)
    offsets = offsets.transpose(2, 0, 1).reshape(2 * n_points, -1)

    return offsets / np.linalg.norm(offsets, axis=0)


def copy_area(tar, src, lms):
    """Move the padded bounding box around ``lms`` from ``src`` into ``tar``.

    The box spans the min/max of the landmark coordinates, grown by 16
    pixels on every side. Its content is copied from ``src`` to ``tar`` and
    then zeroed in ``src``. Both arrays are modified in place.

    Args:
        tar: destination array, indexed as ``[row, col]``.
        src: source array, indexed as ``[row, col]``; the copied region is
            set to 0 afterwards.
        lms: landmark array whose column 0 holds row coordinates and
            column 1 holds column coordinates.
    """
    pad = 16
    # Clamp at 0: a negative slice bound is interpreted relative to the end
    # of the axis, which would silently select an empty (or wrong) region
    # when a landmark lies within ``pad`` pixels of the top/left border.
    left = max(int(min(lms[:, 1])) - pad, 0)
    top = max(int(min(lms[:, 0])) - pad, 0)
    right = max(int(max(lms[:, 1])) + pad + 1, 0)
    bottom = max(int(max(lms[:, 0])) + pad + 1, 0)

    # Upper bounds past the array edge are already truncated by slicing.
    tar[top:bottom, left:right] = src[top:bottom, left:right]
    src[top:bottom, left:right] = 0


def rebound_box(mask, mask_B, mask_face):
    """Fill the padded bounding box of each mask with ``mask_face`` values.

    For ``mask`` and ``mask_B`` independently, the bounding box of the
    non-zero entries (along the first two axes) is grown by 16 pixels on
    every side and, inside that box, a copy of the mask is overwritten with
    the corresponding region of ``mask_face``. The inputs are not modified.

    Args:
        mask: first mask array.
        mask_B: second mask array.
        mask_face: array the box contents are taken from; indexed with the
            same row/column slices as the masks.

    Returns:
        Tuple ``(mask_temp, mask_B_temp)`` with the rewritten copies. A mask
        without any non-zero entry is returned as an unchanged copy.
    """
    pad = 16

    def _fill_padded_box(region):
        filled = np.copy(region)
        hits = region.nonzero()
        rows, cols = hits[0], hits[1]
        if rows.size == 0:
            # Nothing to bound; min()/max() of an empty sequence would raise.
            return filled
        # Clamp the lower bounds at 0: a negative slice start would be
        # interpreted relative to the end of the axis and select an empty
        # region when the mask touches the top/left border.
        top = max(int(rows.min()) - pad, 0)
        left = max(int(cols.min()) - pad, 0)
        bottom = int(rows.max()) + pad + 1
        right = int(cols.max()) + pad + 1
        filled[top:bottom, left:right] = mask_face[top:bottom, left:right]
        return filled

    return _fill_padded_box(mask), _fill_padded_box(mask_B)


def calculate_consis_mask(mask, mask_B):
    """Calculate the consistency mask between two images.

    Each channel-first mask stack is shrunk to a quarter of its height and
    width with nearest-neighbour interpolation. Channels 0, 1 and 2 (named
    lip, skin and eye in this file) are then flattened into the columns of a
    per-pixel table in the order skin, eye, lip. The product of the two
    tables is clipped to [0, 1].

    Args:
        mask: array of shape (3, H_a, W_a).
        mask_B: array of shape (3, H_b, W_b).

    Returns:
        Array of shape ((H_a // 4) * (W_a // 4), (H_b // 4) * (W_b // 4)).
    """

    def _quarter(stack):
        height, width = stack.shape[1:]
        # cv2.resize expects channel-last input and takes dsize as (w, h).
        shrunk = cv2.resize(np.transpose(stack, (1, 2, 0)),
                            dsize=(width // 4, height // 4),
                            interpolation=cv2.INTER_NEAREST)
        return np.transpose(shrunk, (2, 0, 1))

    def _pixel_table(stack):
        height, width = stack.shape[1:]
        table = np.zeros((height * width, 3))
        # Column order is skin, eye, lip -> source channels 1, 2, 0.
        for column, channel in enumerate((1, 2, 0)):
            table[:, column] = stack[channel].flatten()
        return table

    table_a = _pixel_table(_quarter(mask))
    table_b = _pixel_table(_quarter(mask_B))

    return np.clip(np.matmul(table_a, np.transpose(table_b)), 0, 1)


def cal_hist(image):
    """Compute the cumulative normalised histogram of the first 3 channels.

    Each of ``image[0]``, ``image[1]`` and ``image[2]`` is binned into 256
    bins over the value range (0, 255); the counts are divided by their
    total and accumulated into a running sum.

    Args:
        image: indexable holding at least three arrays of pixel values.

    Returns:
        List of three lists, each with 256 cumulative frequencies.
    """
    cumulative = []
    for channel_idx in range(3):
        counts, _ = np.histogram(image[channel_idx], bins=256, range=(0, 255))
        frequencies = counts / counts.sum()
        cumulative.append(list(np.cumsum(frequencies)))
    return cumulative


def cal_trans(ref, adj):
    """Build the 256-entry transfer table for histogram matching.

    For every level ``i`` in 1..255 the table stores the first ``j`` in
    1..255 with ``adj[j - 1] <= ref[i] <= adj[j]``. Levels without such a
    ``j`` keep the identity mapping; entry 0 stays 0 and entry 255 is
    always 255. Follows the "Histogram matching" algorithm as described on
    Wikipedia.

    Args:
        ref: 256 cumulative frequencies of the data to be transformed.
        adj: 256 cumulative frequencies of the data to match.

    Returns:
        List of 256 ints mapping source level to target level.
    """
    table = list(range(256))
    for level in range(1, 256):
        wanted = ref[level]
        hit = next(
            (j for j in range(1, 256) if adj[j - 1] <= wanted <= adj[j]),
            None)
        if hit is not None:
            table[level] = hit
    table[255] = 255
    return table


def histogram_matching(dstImg, refImg, index):
    """Match the histogram of selected ``dstImg`` pixels to ``refImg``.

    For each of the first three channels, the selected pixels of ``dstImg``
    are remapped through a transfer table so that their cumulative
    histogram follows that of the selected pixels of ``refImg``.
    ``dstImg`` is modified in place and also returned.

    Args:
        dstImg: channel-first image to transform.
        refImg: channel-first image providing the target histogram.
        index: sequence of four index arrays. ``index[0]``, ``index[1]``
            select the pixels to transform in ``dstImg``; ``index[2]``,
            ``index[3]`` select the pixels used for the histogram of
            ``refImg``.

    Returns:
        ``dstImg`` with the selected pixels replaced.
    """
    channels = range(3)
    dst_align = [dstImg[c, index[0], index[1]] for c in channels]
    ref_align = [refImg[c, index[2], index[3]] for c in channels]
    hist_ref = cal_hist(ref_align)
    hist_dst = cal_hist(dst_align)
    tables = [cal_trans(hist_dst[c], hist_ref[c]) for c in channels]

    # Apply each transfer table as a lookup table in one vectorised gather
    # instead of a per-pixel Python loop. astype() truncates toward zero,
    # the same way int() did for each pixel value.
    remapped = [
        np.asarray(tables[c])[dst_align[c].astype(np.intp)] for c in channels
    ]

    # Write back only after every channel has been remapped, so a failed
    # lookup leaves dstImg untouched.
    for c in channels:
        dstImg[c, index[0], index[1]] = remapped[c]

    return dstImg


def hisMatch(input_data, target_data, mask_src, mask_tar, index):
    """Histogram-match the masked input to the masked target.

    Both masks are clipped to [0, 1] and converted to float32, multiplied
    into their float32 images, and the products are handed to
    ``histogram_matching``.

    Args:
        input_data: image to be transformed.
        target_data: image providing the target histogram.
        mask_src: mask applied to ``input_data``.
        mask_tar: mask applied to ``target_data``.
        index: index arrays forwarded unchanged to ``histogram_matching``.

    Returns:
        The result of ``histogram_matching`` on the two masked images.
    """

    def _masked(image, region):
        weights = np.float32(np.clip(region, 0, 1))
        return np.float32(image) * weights

    return histogram_matching(_masked(input_data, mask_src),
                              _masked(target_data, mask_tar), index)


def mask_preprocess(mask, mask_B):
    """Collect the non-zero coordinates of two masks.

    Args:
        mask: first mask array (at least two-dimensional).
        mask_B: second mask array (at least two-dimensional).

    Returns:
        List ``[mask, mask_B, index, index_2]`` where ``index`` holds the
        first-axis and second-axis non-zero indices of ``mask`` followed by
        those of ``mask_B``, and ``index_2`` holds the same four arrays with
        the ``mask_B`` pair first.
    """

    def _coords(region):
        hits = region.nonzero()
        return hits[0], hits[1]

    first_a, second_a = _coords(mask)
    first_b, second_b = _coords(mask_B)

    forward = [first_a, second_a, first_b, second_b]
    backward = [first_b, second_b, first_a, second_a]
    return [mask, mask_B, forward, backward]


def generate_mask_aug(mask, lmks):
    """Split a label mask into stacked lip, skin and eye masks.

    Labels 1 and 6 form the face mask and labels 7 and 9 the lip mask. The
    padded boxes around landmarks 42-47 and 36-41 (the two eyes) are moved
    out of the face mask into separate eye masks via ``copy_area``; what
    remains of the face mask is used as the skin mask.

    Args:
        mask: label array.
        lmks: landmark array; rows 36-47 are used.

    Returns:
        Array stacking the lip, skin and eye masks along a new first axis,
        in that order.
    """
    skin = np.float32(mask == 1) + np.float32(mask == 6)
    lip = np.float32(mask == 7) + np.float32(mask == 9)

    # Left eye (42-47) first, then right eye (36-41): copy_area zeroes the
    # copied box in ``skin``, so the order matters if the boxes overlap.
    eye_parts = []
    for eye_lmks in (lmks[42:48], lmks[36:42]):
        eye_mask = np.zeros_like(mask)
        copy_area(eye_mask, skin, eye_lmks)
        eye_parts.append(eye_mask)
    eye = eye_parts[0] + eye_parts[1]

    return np.stack((lip, skin, eye), 0)