prepropess_data.py 1.4 KB
Newer Older
Eric.Lee2021's avatar
Eric.Lee2021 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
# -*- encoding: utf-8 -*-
# function: training sample preprocessing

import os
import os.path as osp
import cv2
from transform import *
from PIL import Image

# Facial part names used in the annotation file names. A part's 1-based
# position in this tuple is the label value written into the merged mask,
# so a later part overwrites an earlier one wherever their masks overlap.
ATTS = ('skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
        'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat')


def part_mask_path(anno_root, folder, image_index, att):
    """Return the path of one per-part annotation mask.

    Args:
        anno_root: Root directory of the per-part annotation masks.
        folder: Index of the sub-folder that holds the image's masks.
        image_index: Index of the image; left-padded with zeros to 5 digits.
        att: Part name, e.g. ``'skin'``.

    Returns:
        ``anno_root/<folder>/<image_index>_<att>.png`` built with
        ``os.path.join``.
    """
    file_name = '{}_{}.png'.format(str(image_index).rjust(5, '0'), att)
    return osp.join(anno_root, str(folder), file_name)


def main(image_size=256, folder_count=15, images_per_folder=2000):
    """Merge the per-part masks of every image into one label mask per image.

    For each image index, every existing part mask is read and its selected
    pixels are set to the part's label (1..len(ATTS)); pixels covered by no
    part stay 0. The merged mask is resized to ``image_size`` when that
    differs from 512 and written to ``./CelebAMask-HQ/mask_<image_size>``.
    Prints each processed image index, and finally the number of part files
    found and the number of part files looked up.

    Args:
        image_size: Side length (pixels) of the written masks.
        folder_count: Number of numbered annotation sub-folders to scan.
        images_per_folder: Number of consecutive image indices per sub-folder.

    Raises:
        FileNotFoundError: If the annotation directory does not exist.
    """
    # Imported here so the script does not depend on `transform` happening
    # to re-export `np` through its star import.
    import numpy as np

    face_sep_mask = './CelebAMask-HQ/CelebAMask-HQ-mask-anno'
    mask_path = './CelebAMask-HQ/mask_{}'.format(image_size)

    # Fail fast: without this check a wrong working directory would produce
    # a full set of all-zero masks, because missing part files are skipped.
    if not os.path.isdir(face_sep_mask):
        raise FileNotFoundError(
            'annotation directory not found: {}'.format(face_sep_mask))

    os.makedirs(mask_path, exist_ok=True)

    counter = 0  # part files that exist on disk
    total = 0    # part files looked up
    for i in range(folder_count):
        for j in range(i * images_per_folder, (i + 1) * images_per_folder):
            # Labels are at most len(ATTS), so uint8 is sufficient and is a
            # depth cv2.imwrite can store without an implicit conversion.
            # The boolean indexing below requires part masks to be 512x512.
            mask = np.zeros((512, 512), dtype=np.uint8)

            for label, att in enumerate(ATTS, 1):
                total += 1
                path = part_mask_path(face_sep_mask, i, j, att)
                # Not every image has every part; absent parts are skipped.
                if not os.path.exists(path):
                    continue
                counter += 1
                with Image.open(path) as part_image:
                    sep_mask = np.array(part_image.convert('P'))
                # NOTE(review): 225 is presumably the palette index that the
                # white foreground gets after convert('P') - confirm before
                # changing it to 255.
                mask[sep_mask == 225] = label

            if image_size != 512:
                # Nearest-neighbour keeps label values intact (no blending).
                mask = cv2.resize(mask, (image_size, image_size),
                                  interpolation=cv2.INTER_NEAREST)
            cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
            print(j)

    print(counter, total)


if __name__ == "__main__":
    main()