4.4 KB
Newer Older
FutureSI 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import random

import cv2
import numpy as np
import paddle
import paddle.vision.transforms as T
from paddle.io import Dataset
from PIL import Image, ImageOps

from .builder import DATASETS

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

def data_transform(img, resize_w, resize_h, load_size=286, pos=(0, 0, 256, 256), flip=True, is_image=True):
    """Resize, crop, optionally mirror, and convert a PIL image to a CHW array.

    Args:
        img: PIL image to transform.
        resize_w (int): Width of the resize step.
        resize_h (int): Height of the resize step.
        load_size (int): Unused; kept so existing keyword callers keep working.
        pos (sequence of int): Crop box ``(left, upper, right, lower)`` applied
            after resizing; only the first four entries are used.
        flip (bool): If True, mirror the crop horizontally.
        is_image (bool): True for a photo (bicubic resize, values mapped from
            [0, 255] to [-1, 1]); False for a label/instance map
            (nearest-neighbour resize, values left unscaled).

    Returns:
        numpy.ndarray: float32 array laid out (C, H, W). A 2-D (single
        channel) input gets a leading channel axis of size 1.
    """
    # Photos are smoothed with bicubic; label maps must use nearest-neighbour
    # so interpolation never invents values between neighbouring label ids.
    resample = Image.BICUBIC if is_image else Image.NEAREST
    resized = img.resize((resize_w, resize_h), resample)

    left, top, right, bottom = pos[:4]
    cropped = resized.crop((left, top, right, bottom))
    if flip:
        cropped = ImageOps.mirror(cropped)

    arr = np.array(cropped)
    if arr.ndim < 3:
        # Grayscale / palette maps come back as (H, W); add a channel axis.
        arr = np.expand_dims(arr, 2)
    chw = np.transpose(arr, (2, 0, 1)).astype('float32')

    if is_image:
        # Map [0, 255] to [-1, 1].
        return chw / 255. * 2. - 1.
    return chw

class PhotoPenDataset(Dataset):
    """Training set of (photo, instance map) pairs for PhotoPen.

    Instance maps are read from ``<content_root>/train_inst``. The matching
    photo is ``<content_root>/train_img/<name>``, where ``<name>`` is the
    instance-map file name with ``.png`` replaced by ``.jpg``.

    Args:
        content_root (str): Dataset root directory.
        load_size (int): Length the shorter side is resized to (aspect ratio
            kept). Must be >= ``crop_size``, otherwise ``random.randint``
            raises ``ValueError``.
        crop_size (int): Side of the square random crop taken after resizing.
    """

    def __init__(self, content_root, load_size, crop_size):
        super(PhotoPenDataset, self).__init__()
        inst_dir = os.path.join(content_root, 'train_inst')
        # Only the top level of inst_dir is listed; sub-directories are ignored.
        # os.walk yields nothing for a missing directory, so fail with a clear
        # error instead of leaking a bare StopIteration out of the constructor.
        walked = next(os.walk(inst_dir), None)
        if walked is None:
            raise FileNotFoundError(
                'PhotoPen instance directory not found: {}'.format(inst_dir))
        _, _, inst_list = walked
        # Sorted so the index -> file mapping is deterministic across runs.
        self.inst_list = np.sort(inst_list)
        self.content_root = content_root
        self.load_size = load_size
        self.crop_size = crop_size

    def __getitem__(self, idx):
        """Load sample ``idx``, resize, randomly crop, and convert it.

        Returns:
            dict: ``'img'`` / ``'ins'`` are the arrays produced by
            ``data_transform`` for the photo and the instance map (sharing
            the same crop box), and ``'img_path'`` is the instance-map
            file name.
        """
        inst_name = self.inst_list[idx]
        inst_path = os.path.join(self.content_root, 'train_inst', inst_name)
        img_path = os.path.join(self.content_root, 'train_img',
                                inst_name.replace(".png", ".jpg"))
        # Context managers close the files deterministically; copy()/convert()
        # return fully loaded images that remain valid after the handle closes.
        with Image.open(inst_path) as inst_file:
            ins = inst_file.copy()
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')

        w, h = img.size
        # Scale the shorter side to load_size, preserving the aspect ratio.
        if w < h:
            resize_w, resize_h = self.load_size, int(h * self.load_size / w)
        else:
            resize_w, resize_h = int(w * self.load_size / h), self.load_size
        left = random.randint(0, resize_w - self.crop_size)
        top = random.randint(0, resize_h - self.crop_size)
        # Photo and instance map must share the exact same crop box.
        box = [left, top, left + self.crop_size, top + self.crop_size]
        # Horizontal-flip augmentation is currently disabled.
        flip = False
        img = data_transform(img, resize_w, resize_h, load_size=self.load_size,
                             pos=box, flip=flip, is_image=True)
        ins = data_transform(ins, resize_w, resize_h, load_size=self.load_size,
                             pos=box, flip=flip, is_image=False)
        return {'img': img, 'ins': ins, 'img_path': inst_name}

    def __len__(self):
        return len(self.inst_list)

    def name(self):
        return 'PhotoPenDataset'

class PhotoPenDataset_test(Dataset):
    """Test set of PhotoPen instance maps (no photos).

    Instance maps are read from ``<content_root>/test_inst``.

    Args:
        content_root (str): Dataset root directory.
        load_size (int): Length the shorter side is resized to (aspect ratio
            kept). Must be >= ``crop_size``, otherwise ``random.randint``
            raises ``ValueError``.
        crop_size (int): Side of the square crop taken after resizing. The
            crop origin is random, just as in training.
    """

    def __init__(self, content_root, load_size, crop_size):
        super(PhotoPenDataset_test, self).__init__()
        inst_dir = os.path.join(content_root, 'test_inst')
        # Only the top level of inst_dir is listed; sub-directories are ignored.
        # os.walk yields nothing for a missing directory, so fail with a clear
        # error instead of leaking a bare StopIteration out of the constructor.
        walked = next(os.walk(inst_dir), None)
        if walked is None:
            raise FileNotFoundError(
                'PhotoPen instance directory not found: {}'.format(inst_dir))
        _, _, inst_list = walked
        # Sorted so the index -> file mapping is deterministic across runs.
        self.inst_list = np.sort(inst_list)
        self.content_root = content_root
        self.load_size = load_size
        self.crop_size = crop_size

    def __getitem__(self, idx):
        """Load instance map ``idx``, resize, crop, and convert it.

        Returns:
            dict: ``'ins'`` is the array produced by ``data_transform`` for
            the instance map, and ``'img_path'`` is its file name.
        """
        inst_name = self.inst_list[idx]
        inst_path = os.path.join(self.content_root, 'test_inst', inst_name)
        # Context manager closes the file deterministically; copy() returns a
        # fully loaded image that remains valid after the handle closes.
        with Image.open(inst_path) as inst_file:
            ins = inst_file.copy()

        w, h = ins.size
        # Scale the shorter side to load_size, preserving the aspect ratio.
        if w < h:
            resize_w, resize_h = self.load_size, int(h * self.load_size / w)
        else:
            resize_w, resize_h = int(w * self.load_size / h), self.load_size
        left = random.randint(0, resize_w - self.crop_size)
        top = random.randint(0, resize_h - self.crop_size)
        box = [left, top, left + self.crop_size, top + self.crop_size]
        # Horizontal flip is disabled.
        flip = False
        ins = data_transform(ins, resize_w, resize_h, load_size=self.load_size,
                             pos=box, flip=flip, is_image=False)
        return {'ins': ins, 'img_path': inst_name}

    def __len__(self):
        return len(self.inst_list)

    def name(self):
        # NOTE(review): returns the same name as the training dataset —
        # confirm whether callers rely on this before changing it.
        return 'PhotoPenDataset'