infer.py 10.4 KB
Newer Older
L
lvmengsi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import functools
import os
from PIL import Image
import paddle.fluid as fluid
import paddle
import numpy as np
L
lvmengsi 已提交
26
import imageio
L
lvmengsi 已提交
27 28
import glob
from util.config import add_arguments, print_arguments
L
lvmengsi 已提交
29
from data_reader import celeba_reader_creator
L
lvmengsi 已提交
30
from util.utility import check_attribute_conflict
L
lvmengsi 已提交
31
import copy
L
lvmengsi 已提交
32 33 34 35 36

# Command-line interface. Every flag is registered through util.config's
# add_arguments helper and read later as an attribute of the parsed args.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model_net',         str,   'cgan',            "The model used, choose in [CycleGAN|Pix2pix|StarGAN|STGAN|AttGAN]")
add_arg('net_G',             str,   "resnet_9block",   "Choose the CycleGAN and Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]")
add_arg('init_model',        str,   None,              "The directory of the trained model to load.")
add_arg('output',            str,   "./infer_result",  "The directory the infer result to be saved to.")
add_arg('input_style',       str,   "A",               "The style of the input, A or B")
add_arg('norm_type',         str,   "batch_norm",      "Which normalization to use")
add_arg('use_gpu',           bool,  True,              "Whether to use GPU for inference.")
add_arg('dropout',           bool,  False,             "Whether to use dropout")
add_arg('g_base_dims',       int,   64,                "Base channels in CycleGAN generator")
add_arg('c_dim',             int,   13,                "The number of selected attributes")
add_arg('use_gru',           bool,  False,             "Whether to use GRU")
add_arg('crop_size',         int,   178,               "crop size")
add_arg('image_size',        int,   128,               "image size")
add_arg('selected_attrs',    str,
    "Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young",
"the attributes we selected to change")
add_arg('batch_size',        int,   16,                "Batch size for inference")
add_arg('test_list',         str,   "./data/celeba/test_list_attr_celeba.txt",                "the test list file")
add_arg('dataset_dir',       str,   "./data/celeba/",                "The dataset directory to run inference on (a glob pattern of image files for CycleGAN/Pix2pix)")
add_arg('n_layers',          int,     5,      "Number of layers in the generator")
add_arg('gru_n_layers',      int,     4,      "Number of GRU layers in the generator")
# yapf: enable


# Values of --model_net that infer() knows how to build and run.
_SUPPORTED_MODELS = ('CycleGAN', 'Pix2pix', 'StarGAN', 'STGAN', 'AttGAN')


def _check_args(args):
    """Reject arguments infer() cannot handle, before any graph is built.

    Args:
        args: parsed command-line namespace; reads ``model_net``,
            ``input_style`` and ``init_model``.

    Raises:
        NotImplementedError: ``args.model_net`` is not a supported model.
        ValueError: CycleGAN is selected with an ``input_style`` other than
            "A" or "B", or ``args.init_model`` is missing/empty.
    """
    if args.model_net not in _SUPPORTED_MODELS:
        raise NotImplementedError("model_net {} is not support".format(
            args.model_net))
    if args.model_net == 'CycleGAN' and args.input_style not in ('A', 'B'):
        raise ValueError(
            "Input with style [%s] is not supported." % args.input_style)
    if not args.init_model:
        raise ValueError(
            "--init_model must point to the directory of a trained model.")


def _build_generator(args, image, label_org, label_trg):
    """Add the generator selected by ``args.model_net`` to the default program.

    Args:
        args: parsed command-line namespace, also forwarded as ``cfg``.
        image: the 'input' data variable.
        label_org: the 'label_org_' data variable (used by STGAN/AttGAN).
        label_trg: the 'label_trg_' data variable (StarGAN/STGAN/AttGAN).

    Returns:
        The generator output variable; it is fetched by name at run time.

    Raises:
        ValueError: invalid CycleGAN ``input_style``.
        NotImplementedError: unsupported ``args.model_net``.
    """
    model_net = args.model_net
    if model_net == 'CycleGAN':
        from network.CycleGAN_network import CycleGAN_model
        if args.input_style == "A":
            g_name = "GA"
        elif args.input_style == "B":
            g_name = "GB"
        else:
            raise ValueError(
                "Input with style [%s] is not supported." % args.input_style)
        return CycleGAN_model().network_G(image, name=g_name, cfg=args)
    if model_net == 'Pix2pix':
        from network.Pix2pix_network import Pix2pix_model
        return Pix2pix_model().network_G(image, "generator", cfg=args)
    if model_net == 'StarGAN':
        from network.StarGAN_network import StarGAN_model
        return StarGAN_model().network_G(
            image, label_trg, name="g_main", cfg=args)
    if model_net == 'STGAN':
        from network.STGAN_network import STGAN_model
        model = STGAN_model()
    elif model_net == 'AttGAN':
        from network.AttGAN_network import AttGAN_model
        model = AttGAN_model()
    else:
        raise NotImplementedError(
            "model_net {} is not support".format(model_net))
    # STGAN/AttGAN return two values; only the generated image is needed.
    fake, _ = model.network_G(
        image, label_org, label_trg, cfg=args, name='generator', is_test=True)
    return fake


def _to_nhwc(batch):
    """Return ``batch`` as an (N, H, W, C) array.

    The last three axes are taken as (C, H, W); any leading axes are folded
    into the batch axis. Unlike ``np.squeeze(batch).transpose(...)`` this keeps
    a single-image batch four-dimensional, so the last (possibly size-1)
    batch of a ``drop_last=False`` reader does not crash.
    """
    arr = np.asarray(batch)
    arr = arr.reshape((-1,) + arr.shape[-3:])
    return arr.transpose([0, 2, 3, 1])


def _to_uint8(image):
    """Map values in [-1, 1] to uint8 pixels in [0, 255], clipping outliers."""
    return np.clip((np.asarray(image) + 1) * 127.5, 0, 255).astype(np.uint8)


def _save_image_grid(images, path):
    """Tile a list of equally shaped NHWC batches and write them to ``path``.

    The picture has one column per sample and one row per element of
    ``images``, in list order.
    """
    per_sample = np.concatenate(images, 1)  # (N, K*H, W, C)
    grid = np.concatenate(per_sample, 1)    # (K*H, N*W, C)
    imageio.imwrite(path, _to_uint8(grid))


def _make_test_reader(args):
    """Return the un-shuffled, name-returning celeba test reader."""
    test_reader = celeba_reader_creator(
        image_dir=args.dataset_dir,
        list_filename=args.test_list,
        batch_size=args.batch_size,
        drop_last=False,
        args=args)
    return test_reader.get_test_reader(args, shuffle=False, return_name=True)


def _infer_attribute_edit(args, exe, place, fake):
    """STGAN/AttGAN: for every attribute, invert it and render the edit.

    Writes one grid per batch to ``<output>/fake_img_<first name>``. Row 0
    is the input; row ``i + 1`` has attribute ``i`` inverted.
    """
    reader_test = _make_test_reader(args)
    attr_names = args.selected_attrs.split(',')
    for real_img, label_org, name in reader_test():
        print("read {}".format(name))
        tensor_img = fluid.LoDTensor()
        tensor_img.set(real_img, place)
        tensor_label_org = fluid.LoDTensor()
        tensor_label_org.set(label_org, place)
        tensor_label_trg = fluid.LoDTensor()
        images = [_to_nhwc(real_img)]
        for i in range(args.c_dim):
            label_trg = copy.deepcopy(label_org)
            for j in range(len(label_org)):
                label_trg[j][i] = 1.0 - label_trg[j][i]
                # NOTE(review): called on the whole batch once per sample j;
                # confirm whether it should run once after this loop instead.
                label_trg = check_attribute_conflict(label_trg, attr_names[i],
                                                     attr_names)
            # Assuming 0/1 labels: shift to -0.5/0.5, then double the edited
            # attribute so it becomes -1/1.
            label_trg_scaled = [((x * 2) - 1) * 0.5 for x in label_trg]
            for j in range(len(label_org)):
                label_trg_scaled[j][i] = label_trg_scaled[j][i] * 2.0
            tensor_label_trg.set(label_trg_scaled, place)
            out = exe.run(feed={
                "input": tensor_img,
                "label_org_": tensor_label_org,
                "label_trg_": tensor_label_trg
            },
                          fetch_list=[fake.name])
            images.append(_to_nhwc(out[0]))
        _save_image_grid(images,
                         os.path.join(args.output, "fake_img_" + name[0]))


def _infer_stargan(args, exe, place, fake):
    """StarGAN: render each batch once per attribute with a one-hot target.

    Writes one grid per batch to ``<output>/fake_img_<first name>``. Row 0
    is the input; row ``i + 1`` uses a target label with only attribute
    ``i`` set.
    """
    reader_test = _make_test_reader(args)
    for real_img, label_org, name in reader_test():
        print("read {}".format(name))
        tensor_img = fluid.LoDTensor()
        tensor_img.set(real_img, place)
        images = [_to_nhwc(real_img)]
        for i in range(args.c_dim):
            label_trg = np.zeros([len(label_org), args.c_dim], dtype="float32")
            label_trg[:, i] = 1
            tensor_label_trg = fluid.LoDTensor()
            tensor_label_trg.set(label_trg, place)
            out = exe.run(
                feed={"input": tensor_img,
                      "label_trg_": tensor_label_trg},
                fetch_list=[fake.name])
            images.append(_to_nhwc(out[0]))
        _save_image_grid(images,
                         os.path.join(args.output, "fake_img_" + name[0]))


def _infer_image_translation(args, exe, place, fake):
    """Pix2pix/CycleGAN: translate every file matched by the dataset_dir glob.

    Each result is written to ``<output>/fake_<basename of the input>``.
    """
    for file_path in glob.glob(args.dataset_dir):
        print("read {}".format(file_path))
        image_name = os.path.basename(file_path)
        with Image.open(file_path) as img:
            image = img.convert('RGB')
        # NOTE(review): always resized to 256x256, independent of
        # --image_size (which sizes the 'input' variable); confirm intended.
        image = image.resize((256, 256), Image.BICUBIC)
        image = np.array(image).transpose([2, 0, 1]).astype('float32')
        image = (image / 255.0 - 0.5) / 0.5  # scale pixels to [-1, 1]
        data = image[np.newaxis, :]
        tensor = fluid.LoDTensor()
        tensor.set(data, place)
        out = exe.run(fetch_list=[fake.name], feed={"input": tensor})
        fake_img = _to_nhwc(out[0])[0]
        imageio.imwrite(
            os.path.join(args.output, "fake_" + image_name),
            _to_uint8(fake_img))


def infer(args):
    """Run generator inference for ``args.model_net`` and save the images.

    Builds the generator graph and loads its persistables from
    ``<init_model>/net_G``. Writes the results under ``args.output``,
    creating the directory if it is missing.

    Args:
        args: parsed command-line namespace produced by ``parser``.

    Raises:
        NotImplementedError: unsupported ``args.model_net``.
        ValueError: invalid CycleGAN ``input_style`` or missing ``init_model``.
    """
    _check_args(args)

    data_shape = [-1, 3, args.image_size, args.image_size]
    image = fluid.layers.data(name='input', shape=data_shape, dtype='float32')
    label_org_ = fluid.layers.data(
        name='label_org_', shape=[args.c_dim], dtype='float32')
    label_trg_ = fluid.layers.data(
        name='label_trg_', shape=[args.c_dim], dtype='float32')
    fake = _build_generator(args, image, label_org_, label_trg_)

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    for var in fluid.default_main_program().global_block().all_parameters():
        print(var.name)
    model_path = os.path.join(args.init_model, 'net_G')
    print(model_path)
    fluid.io.load_persistables(exe, model_path)
    print('load params done')

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    if args.model_net in ('AttGAN', 'STGAN'):
        _infer_attribute_edit(args, exe, place, fake)
    elif args.model_net == 'StarGAN':
        _infer_stargan(args, exe, place, fake)
    elif args.model_net in ('Pix2pix', 'CycleGAN'):
        _infer_image_translation(args, exe, place, fake)
    else:
        raise NotImplementedError("model_net {} is not support".format(
            args.model_net))
L
lvmengsi 已提交
231 232 233 234 235 236


if __name__ == '__main__':
    # Script entry point: parse the command line, echo the resolved
    # arguments, then run inference with them.
    args = parser.parse_args()
    print_arguments(args)
    infer(args)