infer.py 17.2 KB
Newer Older
L
lvmengsi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import functools
import os
from PIL import Image
import paddle.fluid as fluid
import paddle
import numpy as np
L
lvmengsi 已提交
26
import imageio
L
lvmengsi 已提交
27 28
import glob
from util.config import add_arguments, print_arguments
L
lvmengsi 已提交
29
from data_reader import celeba_reader_creator, reader_creator, triplex_reader_creator
L
lvmengsi 已提交
30
from util.utility import check_attribute_conflict, check_gpu, save_batch_image, check_version
L
lvmengsi 已提交
31
from util import utility
L
lvmengsi 已提交
32
import copy
L
lvmengsi 已提交
33

L
lvmengsi 已提交
34 35 36 37
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

L
lvmengsi 已提交
38 39 40
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# Command-line flags for infer(); each add_arg call registers "--<name>".
# yapf: disable
add_arg('model_net',         str,   'CGAN',            "The model used, choose in [CycleGAN|Pix2pix|StarGAN|STGAN|AttGAN|CGAN|DCGAN|SPADE]")
add_arg('net_G',             str,   "resnet_9block",   "Choose the CycleGAN and Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]")
add_arg('init_model',        str,   None,              "The directory of the init model; weights are loaded from <init_model>/net_G.")
add_arg('output',            str,   "./infer_result",  "The directory the infer result to be saved to.")
add_arg('input_style',       str,   "A",               "The style of the input, A or B")
add_arg('norm_type',         str,   "batch_norm",      "Which normalization to use")
add_arg('crop_type',         str,   None,              "Which crop type to use")
add_arg('use_gpu',           bool,  True,              "Whether to use GPU for inference.")
add_arg('dropout',           bool,  False,             "Whether to use dropout")
add_arg('g_base_dims',       int,   64,                "Base channels in CycleGAN generator")
add_arg('ngf',               int,   64,                "Base channels in SPADE generator")
add_arg('c_dim',             int,   13,                "the size of attrs")
add_arg('use_gru',           bool,  False,             "Whether to use GRU")
add_arg('crop_size',         int,   178,               "crop size")
add_arg('image_size',        int,   128,               "image size")
add_arg('load_height',       int,   128,               "height of the loaded image")
add_arg('load_width',        int,   128,               "width of the loaded image")
add_arg('crop_height',       int,   128,               "height of crop size")
add_arg('crop_width',        int,   128,               "width of crop size")
add_arg('selected_attrs',    str,
    "Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young",
    "the attributes we selected to change")
add_arg('n_samples',         int,   16,                "batch size at test time")
add_arg('test_list',         str,   "./data/celeba/list_attr_celeba.txt", "the test list file")
add_arg('dataset_dir',       str,   "./data/celeba/",  "the dataset directory to be inferred")
add_arg('n_layers',          int,   5,                 "default layers in generator")
add_arg('gru_n_layers',      int,   4,                 "default layers of GRU in generator")
add_arg('noise_size',        int,   100,               "the noise dimension")
add_arg('label_nc',          int,   36,                "number of label channels for SPADE")
add_arg('no_instance',       bool,  False,             "Whether to use instance label.")
# yapf: enable


def infer(args):
    """Build a GAN generator, load its weights and write inference results.

    The generator graph is selected by ``args.model_net`` (CycleGAN, Pix2pix,
    StarGAN, STGAN, AttGAN, CGAN, DCGAN or SPADE). Persistable variables are
    loaded from ``os.path.join(args.init_model, 'net_G')``. Results are saved
    as image files under ``args.output``, which is created if missing.

    The image-to-image and attribute-editing models read test samples from
    ``args.dataset_dir`` / ``args.test_list``. CGAN and DCGAN instead sample
    ``args.n_samples`` uniform noise vectors and save one plotted figure.

    Args:
        args: parsed command-line namespace (see the ``add_arg`` calls in
            this module).

    Raises:
        NotImplementedError: ``args.model_net`` is not a supported model.
        ValueError: ``args.input_style`` is not "A"/"B" for CycleGAN, or
            ``args.init_model`` is not set.
    """
    data_shape = [-1, 3, args.image_size, args.image_size]
    # Feed slots shared by the image-to-image and attribute-editing models.
    input = fluid.layers.data(name='input', shape=data_shape, dtype='float32')
    label_org_ = fluid.layers.data(
        name='label_org_', shape=[args.c_dim], dtype='float32')
    label_trg_ = fluid.layers.data(
        name='label_trg_', shape=[args.c_dim], dtype='float32')
    image_name = fluid.layers.data(
        name='image_name', shape=[args.n_samples], dtype='int32')

    # Sub-directory of args.init_model holding the generator weights.
    model_name = 'net_G'

    # ---- build the generator graph for the requested model ----
    if args.model_net == 'CycleGAN':
        loader = fluid.io.DataLoader.from_generator(
            feed_list=[input, image_name],
            capacity=4,  ## batch_size * 4
            iterable=True,
            use_double_buffer=True)

        from network.CycleGAN_network import CycleGAN_model
        model = CycleGAN_model()
        if args.input_style == "A":
            fake = model.network_G(input, name="GA", cfg=args)
        elif args.input_style == "B":
            fake = model.network_G(input, name="GB", cfg=args)
        else:
            # Raising a bare str is itself a TypeError in Python 3; raise a
            # real exception so the intended message reaches the user.
            raise ValueError("Input with style [%s] is not supported." %
                             args.input_style)
    elif args.model_net == 'Pix2pix':
        loader = fluid.io.DataLoader.from_generator(
            feed_list=[input, image_name],
            capacity=4,  ## batch_size * 4
            iterable=True,
            use_double_buffer=True)

        from network.Pix2pix_network import Pix2pix_model
        model = Pix2pix_model()
        fake = model.network_G(input, "generator", cfg=args)
    elif args.model_net == 'StarGAN':
        py_reader = fluid.io.PyReader(
            feed_list=[input, label_org_, label_trg_, image_name],
            capacity=32,
            iterable=True,
            use_double_buffer=True)

        from network.StarGAN_network import StarGAN_model
        model = StarGAN_model()
        fake = model.network_G(input, label_trg_, name="g_main", cfg=args)
    elif args.model_net == 'STGAN':
        from network.STGAN_network import STGAN_model

        py_reader = fluid.io.PyReader(
            feed_list=[input, label_org_, label_trg_, image_name],
            capacity=32,
            iterable=True,
            use_double_buffer=True)

        model = STGAN_model()
        fake, _ = model.network_G(
            input,
            label_org_,
            label_trg_,
            cfg=args,
            name='generator',
            is_test=True)
    elif args.model_net == 'AttGAN':
        from network.AttGAN_network import AttGAN_model

        py_reader = fluid.io.PyReader(
            feed_list=[input, label_org_, label_trg_, image_name],
            capacity=32,
            iterable=True,
            use_double_buffer=True)

        model = AttGAN_model()
        fake, _ = model.network_G(
            input,
            label_org_,
            label_trg_,
            cfg=args,
            name='generator',
            is_test=True)
    elif args.model_net == 'CGAN':
        noise = fluid.layers.data(
            name='noise', shape=[args.noise_size], dtype='float32')
        conditions = fluid.layers.data(
            name='conditions', shape=[1], dtype='float32')

        from network.CGAN_network import CGAN_model
        model = CGAN_model(args.n_samples)
        fake = model.network_G(noise, conditions, name="G")
    elif args.model_net == 'DCGAN':
        noise = fluid.layers.data(
            name='noise', shape=[args.noise_size], dtype='float32')

        from network.DCGAN_network import DCGAN_model
        model = DCGAN_model(args.n_samples)
        fake = model.network_G(noise, name="G")
    elif args.model_net == 'SPADE':
        label_shape = [None, args.label_nc, args.crop_height, args.crop_width]
        spade_data_shape = [None, 1, args.crop_height, args.crop_width]
        from network.SPADE_network import SPADE_model
        model = SPADE_model()
        input_label = fluid.data(
            name='input_label', shape=label_shape, dtype='float32')
        input_ins = fluid.data(
            name='input_ins', shape=spade_data_shape, dtype='float32')
        # The generator consumes label map and instance map stacked on axis 1.
        input_ = fluid.layers.concat([input_label, input_ins], 1)
        fake = model.network_G(input_, "generator", cfg=args, is_test=True)
    else:
        raise NotImplementedError("model_net {} is not support".format(
            args.model_net))

    def _compute_start_end(image_name):
        """Log the id range of the current batch; return its output file name."""
        image_name_start = np.array(image_name)[0].astype('int32')
        image_name_end = image_name_start + args.n_samples - 1
        print("read {}.jpg ~ {}.jpg".format(image_name_start, image_name_end))
        return str(image_name_start) + '.jpg'

    # ---- prepare environment and load weights ----
    if args.init_model is None:
        # Without this check the path concatenation below fails with an
        # opaque "NoneType + str" TypeError.
        raise ValueError(
            "init_model must be set to the directory of the trained model")
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    for var in fluid.default_main_program().global_block().all_parameters():
        print(var.name)
    model_dir = os.path.join(args.init_model, model_name)
    print(model_dir)
    fluid.io.load_persistables(exe, model_dir)
    print('load params done')
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    attr_names = args.selected_attrs.split(',')

    # ---- run inference and save results ----
    if args.model_net == 'AttGAN' or args.model_net == 'STGAN':
        test_reader = celeba_reader_creator(
            image_dir=args.dataset_dir,
            list_filename=args.test_list,
            args=args,
            mode="VAL")
        reader_test = test_reader.make_reader(return_name=True)
        py_reader.decorate_batch_generator(
            reader_test,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
        for data in py_reader():
            real_img, label_org, label_trg, image_name = data[0]['input'], data[
                0]['label_org_'], data[0]['label_trg_'], data[0]['image_name']
            image_name_save = _compute_start_end(image_name)
            real_img_temp = save_batch_image(np.array(real_img))
            images = [real_img_temp]
            # One generated image per attribute: flip attribute i in the
            # target labels and run the generator on the whole batch.
            for i in range(args.c_dim):
                label_trg_tmp = copy.deepcopy(np.array(label_trg))
                for j in range(len(label_trg_tmp)):
                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
                    label_trg_tmp = check_attribute_conflict(
                        label_trg_tmp, attr_names[i], attr_names)
                # Rescale labels with ((x * 2) - 1) * 0.5: 0 -> -0.5, 1 -> 0.5.
                label_org_tmp = list(
                    map(lambda x: ((x * 2) - 1) * 0.5, np.array(label_org)))
                label_trg_tmp = list(
                    map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
                if args.model_net == 'AttGAN':
                    # For AttGAN the edited attribute's target value is doubled.
                    for k in range(len(label_trg_tmp)):
                        label_trg_tmp[k][i] = label_trg_tmp[k][i] * 2.0
                tensor_label_org_ = fluid.LoDTensor()
                tensor_label_trg_ = fluid.LoDTensor()
                tensor_label_org_.set(label_org_tmp, place)
                tensor_label_trg_.set(label_trg_tmp, place)
                out = exe.run(feed={
                    "input": real_img,
                    "label_org_": tensor_label_org_,
                    "label_trg_": tensor_label_trg_
                },
                              fetch_list=[fake.name])
                fake_temp = save_batch_image(out[0])
                images.append(fake_temp)
            images_concat = np.concatenate(images, 1)
            if len(np.array(label_org)) > 1:
                images_concat = np.concatenate(images_concat, 1)
            imageio.imwrite(
                os.path.join(args.output, "fake_img_" + image_name_save), (
                    (images_concat + 1) * 127.5).astype(np.uint8))
    elif args.model_net == 'StarGAN':
        test_reader = celeba_reader_creator(
            image_dir=args.dataset_dir,
            list_filename=args.test_list,
            args=args,
            mode="VAL")
        reader_test = test_reader.make_reader(return_name=True)
        py_reader.decorate_batch_generator(
            reader_test,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
        for data in py_reader():
            real_img, label_org, label_trg, image_name = data[0]['input'], data[
                0]['label_org_'], data[0]['label_trg_'], data[0]['image_name']
            image_name_save = _compute_start_end(image_name)
            real_img_temp = save_batch_image(np.array(real_img))
            images = [real_img_temp]
            # StarGAN builds its targets from the ORIGINAL labels, flipping
            # one attribute per generated image.
            for i in range(args.c_dim):
                label_trg_tmp = copy.deepcopy(np.array(label_org))
                for j in range(len(np.array(label_org))):
                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
                    label_trg_tmp = check_attribute_conflict(
                        label_trg_tmp, attr_names[i], attr_names)
                tensor_label_trg_ = fluid.LoDTensor()
                tensor_label_trg_.set(label_trg_tmp, place)
                out = exe.run(
                    feed={"input": real_img,
                          "label_trg_": tensor_label_trg_},
                    fetch_list=[fake.name])
                fake_temp = save_batch_image(out[0])
                images.append(fake_temp)
            images_concat = np.concatenate(images, 1)
            if len(np.array(label_org)) > 1:
                images_concat = np.concatenate(images_concat, 1)
            imageio.imwrite(
                os.path.join(args.output, "fake_img_" + image_name_save), (
                    (images_concat + 1) * 127.5).astype(np.uint8))

    elif args.model_net == 'Pix2pix' or args.model_net == 'CycleGAN':
        test_reader = reader_creator(
            image_dir=args.dataset_dir,
            list_filename=args.test_list,
            shuffle=False,
            batch_size=args.n_samples,
            mode="VAL")
        reader_test = test_reader.make_reader(args, return_name=True)
        loader.set_batch_generator(
            reader_test,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
        id2name = test_reader.id2name
        for data in loader():
            real_img, image_name = data[0]['input'], data[0]['image_name']
            # The reader yields integer ids; map the first one back to a name.
            image_name = id2name[np.array(image_name).astype('int32')[0]]
            print("read: ", image_name)
            fake_temp = exe.run(fetch_list=[fake.name],
                                feed={"input": real_img})
            fake_temp = np.squeeze(fake_temp[0]).transpose([1, 2, 0])
            # Generator output is mapped back to [0, 255] via (x + 1) * 127.5.
            imageio.imwrite(
                os.path.join(args.output, "fake_" + image_name), (
                    (fake_temp + 1) * 127.5).astype(np.uint8))
    elif args.model_net == 'SPADE':
        test_reader = triplex_reader_creator(
            image_dir=args.dataset_dir,
            list_filename=args.test_list,
            shuffle=False,
            batch_size=1,
            mode="TEST")
        id2name = test_reader.id2name
        reader_test = test_reader.make_reader(args, return_name=True)
        for data_A, data_B, data_C, name in reader_test():
            name = id2name[np.array(name).astype('int32')[0]]
            print("read: ", name)
            tensor_A = fluid.LoDTensor()
            tensor_C = fluid.LoDTensor()
            tensor_A.set(data_A, place)
            tensor_C.set(data_C, place)
            fake_B_temp = exe.run(
                fetch_list=[fake.name],
                feed={"input_label": tensor_A,
                      "input_ins": tensor_C})
            fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
            input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0])

            imageio.imwrite(
                os.path.join(args.output, "fakeB_" + "_" + name),
                ((fake_B_temp + 1) * 127.5).astype(np.uint8))
            imageio.imwrite(
                os.path.join(args.output, "real_" + "_" + name),
                ((input_B_temp + 1) * 127.5).astype(np.uint8))

    elif args.model_net == 'CGAN':
        noise_data = np.random.uniform(
            low=-1.0, high=1.0,
            size=[args.n_samples, args.noise_size]).astype('float32')
        # randint's upper bound is exclusive: use 10 so that every condition
        # 0..9 can be sampled (the previous bound of 9 never produced 9).
        label = np.random.randint(
            0, 10, size=[args.n_samples, 1]).astype('float32')
        noise_tensor = fluid.LoDTensor()
        conditions_tensor = fluid.LoDTensor()
        noise_tensor.set(noise_data, place)
        conditions_tensor.set(label, place)
        fake_temp = exe.run(
            fetch_list=[fake.name],
            feed={"noise": noise_tensor,
                  "conditions": conditions_tensor})[0]
        # Flatten each sample so utility.plot receives one row per image.
        fake_image = np.reshape(fake_temp, (args.n_samples, -1))

        fig = utility.plot(fake_image)
        plt.savefig(
            os.path.join(args.output, 'fake_cgan.png'), bbox_inches='tight')
        plt.close(fig)

    elif args.model_net == 'DCGAN':
        noise_data = np.random.uniform(
            low=-1.0, high=1.0,
            size=[args.n_samples, args.noise_size]).astype('float32')
        noise_tensor = fluid.LoDTensor()
        noise_tensor.set(noise_data, place)
        fake_temp = exe.run(fetch_list=[fake.name],
                            feed={"noise": noise_tensor})[0]
        # Flatten each sample so utility.plot receives one row per image.
        fake_image = np.reshape(fake_temp, (args.n_samples, -1))

        fig = utility.plot(fake_image)
        plt.savefig(
            os.path.join(args.output, 'fake_dcgan.png'), bbox_inches='tight')
        plt.close(fig)
    else:
        raise NotImplementedError("model_net {} is not support".format(
            args.model_net))
L
lvmengsi 已提交
385 386 387 388 389


if __name__ == "__main__":
    args = parser.parse_args()
    # Echo the effective configuration before doing any work.
    print_arguments(args)
    # Sanity-check the runtime setup with the project helpers before
    # building the inference graph.
    check_gpu(args.use_gpu)
    check_version()
    infer(args)