lapstyle_dataset.py 2.9 KB
Newer Older
W
wangna11BD 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import numpy as np
from PIL import Image
import paddle
import paddle.vision.transforms as T
from paddle.io import Dataset

from .builder import DATASETS

logger = logging.getLogger(__name__)


def data_transform(crop_size):
    """Build the augmentation pipeline shared by content and style images.

    Args:
        crop_size: output size forwarded to ``T.RandomCrop``.

    Returns:
        A ``T.Compose`` whose only step is a random crop of ``crop_size``.
    """
    random_crop = T.RandomCrop(crop_size)
    return T.Compose([random_crop])


@DATASETS.register()
class LapStyleDataset(Dataset):
    """Dataset pairing content images (e.g. coco2017) with one style image,
    used to train the LapStyle model.

    Every entry of ``content_root`` is treated as a content image; the single
    image at ``style_root`` is paired with every content sample.

    Args:
        content_root (str): directory whose entries are content image files.
        style_root (str): path to the style image file.
        load_size (int): both images are resized to ``load_size x load_size``
            (bilinear) before cropping.
        crop_size: size forwarded to ``T.RandomCrop`` via ``data_transform``.
    """
    def __init__(self, content_root, style_root, load_size, crop_size):
        super().__init__()
        self.content_root = content_root
        # Sort so the index -> file mapping does not depend on the arbitrary
        # order in which the filesystem lists directory entries.
        self.paths = sorted(os.listdir(self.content_root))
        self.style_root = style_root
        self.load_size = load_size
        self.crop_size = crop_size
        self.transform = data_transform(self.crop_size)
        # The style image is identical for every sample, so decode and resize
        # it once here instead of on every __getitem__ call.
        self._style_array = self._load_resized(self.style_root)

    def _load_resized(self, path):
        """Open ``path`` as RGB, resize it to ``load_size x load_size`` with
        bilinear filtering and return it as an HWC uint8 numpy array."""
        # The context manager guarantees the underlying file is closed;
        # convert() returns a fully loaded, independent image.
        with Image.open(path) as image:
            rgb = image.convert('RGB')
        resized = rgb.resize((self.load_size, self.load_size), Image.BILINEAR)
        return np.array(resized)

    def __getitem__(self, index):
        """Get a training sample.

        return:
            dict with keys
                'ci': content image, float32 array of shape [C,H,W] in [0, 1];
                'si': style image, float32 array of shape [C,H,W] in [0, 1];
                'ci_path': str, file name of the content image inside
                    ``content_root``.
        """
        path = self.paths[index]
        content_img = self._load_resized(os.path.join(self.content_root, path))
        # Copy so that no transform can ever write into the cached style array.
        style_img = self._style_array.copy()
        # Content is cropped before style, keeping the original order of
        # random draws.
        content_img = self.transform(content_img)
        style_img = self.transform(style_img)
        content_img = self.img(content_img)
        style_img = self.img(style_img)
        return {'ci': content_img, 'si': style_img, 'ci_path': path}

    def img(self, img):
        """Convert an HWC image with values in [0, 255] to CHW in [0, 1].

        Args:
            img (np.ndarray): image of shape [H, W, C], or [H, W] for a
                single-channel image.

        return:
            img: C-contiguous float32 array of shape [min(C, 3), H, W]
                (3 channels for 2-D input) with values in [0, 1].
        """
        # Single-channel 2-D input: replicate to three channels so the HWC
        # indexing below is valid.
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        # Some images have 4 channels (e.g. alpha); keep only the first three.
        img = img[:, :, :3]
        # [0,255] to [0,1]
        img = img.astype(np.float32) / 255.
        # HWC to CHW; ascontiguousarray materialises the transposed view.
        return np.ascontiguousarray(np.transpose(img, (2, 0, 1)))

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'LapStyleDataset'