utility.py 16.9 KB
Newer Older
L
lvmengsi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import os
import sys
import math
import distutils.util
import numpy as np
import inspect
import matplotlib
import six
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
L
lvmengsi 已提交
30
import copy
L
lvmengsi 已提交
31
from PIL import Image
L
lvmengsi 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51

img_dim = 28


def plot(gen_data):
    """Tile a batch of flattened square images into one grayscale figure.

    Every sample of ``gen_data`` is reshaped to ``img_dim x img_dim``. The
    images are laid out on the smallest square grid that can hold them all
    (missing cells are zero-filled), each tile gets a 1-pixel border on its
    top and left edges, and the mosaic is drawn with the ``Greys_r``
    colormap over the fixed value range [-1, 1].

    Args:
        gen_data: numpy array whose leading axis is the batch and whose
            remaining elements reshape to ``img_dim * img_dim`` per sample.

    Returns:
        The matplotlib figure that was drawn on.
    """
    border = 1
    cell = img_dim + border
    images = gen_data.reshape(gen_data.shape[0], img_dim, img_dim)
    count = images.shape[0]
    side = int(math.ceil(math.sqrt(count)))

    # Fill the batch up to side*side tiles and prepend the border rows/cols.
    padding = [[0, side * side - count], [border, 0], [border, 0]]
    tiles = np.pad(images, padding, 'constant')

    # (side*side, cell, cell) -> (side*cell, side*cell) mosaic.
    mosaic = tiles.reshape((side, side, cell, cell))
    mosaic = mosaic.transpose((0, 2, 1, 3)).reshape((side * cell, side * cell))

    figure = plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(mosaic, cmap='Greys_r', vmin=-1, vmax=1)
    return figure


def checkpoints(epoch, cfg, exe, trainer, name):
    """Save the persistable variables of ``trainer.program`` for an epoch.

    The variables are written with ``fluid.io.save_persistables`` (no
    ``filename`` argument) into ``<cfg.output>/checkpoints/<epoch>/<name>``.
    The epoch directory is created if it does not exist yet.

    Args:
        epoch: epoch number, used as a directory name.
        cfg: config object; ``cfg.output`` is the output root directory.
        exe: executor passed to ``fluid.io.save_persistables``.
        trainer: object whose ``program`` holds the variables to save.
        name: sub-directory name for this trainer's variables.
    """
    output_path = os.path.join(cfg.output, 'checkpoints', str(epoch))
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    fluid.io.save_persistables(
        exe, os.path.join(output_path, name), main_program=trainer.program)
    print('save checkpoints {} to {}'.format(name, output_path))
    sys.stdout.flush()


def init_checkpoints(cfg, exe, trainer, name):
    """Load persistable variables from ``<cfg.init_model>/<name>``.

    The values are loaded into ``trainer.program`` with
    ``fluid.io.load_persistables`` (no ``filename`` argument), presumably
    matching the layout written by ``checkpoints``.

    Args:
        cfg: config object; ``cfg.init_model`` is the checkpoint directory.
        exe: executor passed to ``fluid.io.load_persistables``.
        trainer: object whose ``program`` receives the loaded values.
        name: sub-directory of ``cfg.init_model`` to load from.

    Raises:
        AssertionError: if ``cfg.init_model`` does not exist.
    """
    # Explicit raise instead of ``assert`` so the check also runs under
    # ``python -O``; the exception type stays AssertionError for callers.
    if not os.path.exists(cfg.init_model):
        raise AssertionError("{} cannot be found.".format(cfg.init_model))
    fluid.io.load_persistables(
        exe, os.path.join(cfg.init_model, name), main_program=trainer.program)
    print('load checkpoints {} {} DONE'.format(cfg.init_model, name))
    sys.stdout.flush()


L
lvmengsi 已提交
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
def init_from_checkpoint(args, exe, trainer, name):
    """Load a single-file checkpoint into ``trainer.program``.

    All persistable variables are read from
    ``<args.init_model>/<name>/checkpoint.pdparams``. This is the layout
    ``save_checkpoint`` writes below ``<output>/checkpoints/<epoch>``.

    Args:
        args: config object; ``args.init_model`` is the checkpoint root.
        exe: executor passed to ``fluid.io.load_persistables``.
        trainer: object whose ``program`` receives the loaded values.
        name: sub-directory of ``args.init_model`` holding the file.

    Returns:
        True once the variables have been loaded.

    Raises:
        Warning: if ``args.init_model`` does not exist.
    """
    if not os.path.exists(args.init_model):
        # Raised rather than returning False, so a missing checkpoint can
        # never be silently ignored by the caller.
        raise Warning("the checkpoint path does not exist.")

    fluid.io.load_persistables(
        executor=exe,
        dirname=os.path.join(args.init_model, name),
        main_program=trainer.program,
        filename="checkpoint.pdparams")

    print("finish initing model from checkpoint from %s" % (args.init_model))

    return True


def save_param(args, exe, program, dirname, var_name="generator"):
    """Save the parameters of one sub-network into a single file.

    A variable of ``program`` is saved when ``fluid.io.is_parameter``
    accepts it and its name starts with ``var_name``. The output file is
    ``<args.output>/infer_vars/<dirname>/params.pdparams``.

    Args:
        args: config object; ``args.output`` is the output root directory.
        exe: executor passed to ``fluid.io.save_vars``.
        program: program whose variables are filtered and saved.
        dirname: sub-directory of ``infer_vars`` to write into.
        var_name: name prefix that selects the parameters to save.

    Returns:
        True once the parameters have been written.
    """
    param_dir = os.path.join(args.output, 'infer_vars')

    if not os.path.exists(param_dir):
        os.makedirs(param_dir)

    def _name_has_generator(var):
        # Keep only parameters of the requested sub-network. There is no
        # per-variable print here, which would otherwise flood stdout with
        # every variable in the program.
        return fluid.io.is_parameter(var) and var.name.startswith(var_name)

    save_path = os.path.join(param_dir, dirname)
    fluid.io.save_vars(
        exe,
        save_path,
        main_program=program,
        predicate=_name_has_generator,
        filename="params.pdparams")
    print("save parameters at %s" % (save_path))

    return True


def save_checkpoint(epoch, args, exe, trainer, dirname):
    """Write every persistable variable of ``trainer.program`` into one file.

    The file is ``<args.output>/checkpoints/<epoch>/<dirname>/checkpoint.pdparams``.
    The epoch directory is created if it does not exist yet.

    Args:
        epoch: epoch number, used as a directory name.
        args: config object; ``args.output`` is the output root directory.
        exe: executor passed to ``fluid.io.save_persistables``.
        trainer: object whose ``program`` holds the variables to save.
        dirname: sub-directory that receives ``checkpoint.pdparams``.

    Returns:
        True once the checkpoint has been written.
    """
    epoch_dir = os.path.join(args.output, 'checkpoints', str(epoch))
    if not os.path.exists(epoch_dir):
        os.makedirs(epoch_dir)

    target = os.path.join(epoch_dir, dirname)
    fluid.io.save_persistables(
        exe,
        target,
        main_program=trainer.program,
        filename="checkpoint.pdparams")
    print("save checkpoint at %s" % (target))
    return True


Z
zhumanyu 已提交
130 131 132 133 134 135 136
def _to_uint8_image(img):
    """Convert an HWC array with values in [-1, 1] into a PIL image.

    Values are mapped with ``(x + 1) * 127.5`` and clipped to [0, 255]
    before the uint8 cast, so slightly out-of-range outputs saturate
    instead of wrapping around.
    """
    return Image.fromarray(
        np.clip((img + 1) * 127.5, 0, 255).astype(np.uint8))


def _save_image(img, out_path, file_name):
    """Save ``img`` (values in [-1, 1]) as ``out_path/file_name``."""
    _to_uint8_image(img).save(os.path.join(out_path, file_name))


def _chw_to_hwc(img):
    """Squeeze singleton axes and move channels last (CHW -> HWC)."""
    return np.squeeze(img).transpose([1, 2, 0])


def _lookup_name(id2name, name_tensor):
    """Map the first id stored in ``name_tensor`` to a file name."""
    return id2name[np.array(name_tensor).astype('int32')[0]]


def _save_pix2pix_results(epoch, exe, test_program, g_trainer, A_test_reader,
                          A_id2name, out_path):
    """Translate each A test image and save fakeB/inputA/inputB images."""
    for data in A_test_reader():
        sample = data[0]
        A_data, B_data = sample['input_A'], sample['input_B']
        fake_B = exe.run(test_program,
                         fetch_list=[g_trainer.fake_B],
                         feed={"input_A": A_data,
                               "input_B": B_data})
        suffix = str(epoch) + "_" + _lookup_name(A_id2name,
                                                 sample['image_name'])
        _save_image(_chw_to_hwc(fake_B[0]), out_path, "fakeB_" + suffix)
        _save_image(
            _chw_to_hwc(np.array(A_data)[0]), out_path, "inputA_" + suffix)
        _save_image(
            _chw_to_hwc(np.array(B_data)[0]), out_path, "inputB_" + suffix)


def _save_spade_results(epoch, exe, test_program, g_trainer, A_test_reader,
                        A_id2name, out_path):
    """Synthesize an image per test label map and save fakeB/real images."""
    for data in A_test_reader():
        sample = data[0]
        real_img = sample['input_img']
        fake_B = exe.run(test_program,
                         fetch_list=[g_trainer.fake_B],
                         feed={
                             "input_label": sample['input_label'],
                             "input_img": real_img,
                             "input_ins": sample['input_ins']
                         })
        suffix = str(epoch) + "_" + _lookup_name(A_id2name,
                                                 sample['image_name'])
        # os.path.join instead of string concatenation, like other branches.
        _save_image(_chw_to_hwc(fake_B[0]), out_path, "fakeB_" + suffix)
        _save_image(_chw_to_hwc(real_img[0]), out_path, "real_" + suffix)


def _save_stargan_results(epoch, cfg, exe, place, test_program, g_trainer,
                          A_test_reader, out_path):
    """Flip each attribute in turn and save a strip of fake/rec images."""
    attr_names = cfg.selected_attrs.split(',')
    for data in A_test_reader():
        sample = data[0]
        real_img = sample['image_real']
        label_org = sample['label_org']
        label_trg = sample['label_trg']
        np_label_org = np.array(label_org)

        images = [save_batch_image(np.array(real_img))]
        for i in range(cfg.c_dim):
            # Flip attribute i for every sample first, then resolve
            # conflicting attributes once for the whole batch.
            np_label_trg = np_label_org.copy()
            for row in np_label_trg:
                row[i] = 1.0 - row[i]
            np_label_trg = check_attribute_conflict(np_label_trg,
                                                    attr_names[i], attr_names)
            label_trg.set(np_label_trg, place)
            fake_img, rec_img = exe.run(
                test_program,
                feed={
                    "image_real": real_img,
                    "label_org": label_org,
                    "label_trg": label_trg
                },
                fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
            images.append(save_batch_image(fake_img))
            images.append(save_batch_image(rec_img))

        images_concat = np.concatenate(images, 1)
        if len(np_label_org) > 1:
            images_concat = np.concatenate(images_concat, 1)
        image_name_save = "fake_img" + str(epoch) + "_" + str(
            np.array(sample['image_name'])[0].astype('int32')) + '.jpg'
        _save_image(images_concat, out_path, image_name_save)


def _save_attgan_results(epoch, cfg, exe, place, test_program, g_trainer,
                         A_test_reader, out_path):
    """Edit each attribute in turn (AttGAN/STGAN) and save an image strip."""
    attr_names = cfg.selected_attrs.split(',')
    for data in A_test_reader():
        sample = data[0]
        real_img = sample['image_real']
        label_org = sample['label_org']
        label_trg = sample['label_trg']
        np_label_trg = np.array(label_trg)
        # Labels mapped with ((x * 2) - 1) * 0.5; kept as a list of rows,
        # which is the form fed to LoDTensor.set.
        label_org_shift = [((x * 2) - 1) * 0.5 for x in np.array(label_org)]

        images = [save_batch_image(np.array(real_img))]
        for i in range(cfg.c_dim):
            # Flip attribute i for every sample first, then resolve
            # conflicting attributes once for the whole batch.
            flipped = np_label_trg.copy()
            for row in flipped:
                row[i] = 1.0 - row[i]
            flipped = check_attribute_conflict(flipped, attr_names[i],
                                               attr_names)
            label_trg_shift = [((x * 2) - 1) * 0.5 for x in flipped]
            if cfg.model_net == 'AttGAN':
                # AttGAN only: the edited attribute's value is doubled.
                for row in label_trg_shift:
                    row[i] = row[i] * 2.0

            tensor_label_org_ = fluid.LoDTensor()
            tensor_label_org_.set(label_org_shift, place)
            tensor_label_trg_ = fluid.LoDTensor()
            tensor_label_trg_.set(label_trg_shift, place)

            out = exe.run(test_program,
                          feed={
                              "image_real": real_img,
                              "label_org": label_org,
                              "label_org_": tensor_label_org_,
                              "label_trg": label_trg,
                              "label_trg_": tensor_label_trg_
                          },
                          fetch_list=[g_trainer.fake_img])
            images.append(save_batch_image(out[0]))

        images_concat = np.concatenate(images, 1)
        if len(np_label_trg) > 1:
            images_concat = np.concatenate(images_concat, 1)
        image_name_save = 'fake_img_' + str(epoch) + '_' + str(
            np.array(sample['image_name'])[0].astype('int32')) + '.jpg'
        _save_image(images_concat, out_path, image_name_save)


def _save_cycle_results(epoch, exe, test_program, g_trainer, A_test_reader,
                        B_test_reader, A_id2name, B_id2name, out_path):
    """Run both directions on paired A/B samples and save fake/cyc/input."""
    for data_A, data_B in zip(A_test_reader(), B_test_reader()):
        A_data = data_A[0]['input_A']
        B_data = data_B[0]['input_B']
        fake_A, fake_B, cyc_A, cyc_B = exe.run(
            test_program,
            fetch_list=[
                g_trainer.fake_A, g_trainer.fake_B, g_trainer.cyc_A,
                g_trainer.cyc_B
            ],
            feed={"input_A": A_data,
                  "input_B": B_data})
        suffix_A = str(epoch) + "_" + _lookup_name(A_id2name,
                                                   data_A[0]['A_image_name'])
        suffix_B = str(epoch) + "_" + _lookup_name(B_id2name,
                                                   data_B[0]['B_image_name'])
        _save_image(_chw_to_hwc(fake_B[0]), out_path, "fakeB_" + suffix_B)
        _save_image(_chw_to_hwc(fake_A[0]), out_path, "fakeA_" + suffix_A)
        _save_image(_chw_to_hwc(cyc_A[0]), out_path, "cycA_" + suffix_A)
        _save_image(_chw_to_hwc(cyc_B[0]), out_path, "cycB_" + suffix_B)
        _save_image(
            _chw_to_hwc(np.array(A_data)), out_path, "inputA_" + suffix_A)
        _save_image(
            _chw_to_hwc(np.array(B_data)), out_path, "inputB_" + suffix_B)


def save_test_image(epoch,
                    cfg,
                    exe,
                    place,
                    test_program,
                    g_trainer,
                    A_test_reader,
                    B_test_reader=None,
                    A_id2name=None,
                    B_id2name=None):
    """Run the generator on the test data and save the images to disk.

    Images go to ``<cfg.output>/test``, which is created if needed. The
    work done depends on ``cfg.model_net``:

    * ``"Pix2pix"``: saves ``fakeB_``/``inputA_``/``inputB_`` images.
    * ``"SPADE"``: saves ``fakeB_``/``real_`` images.
    * ``"StarGAN"``: saves one strip of real, fake and reconstructed images
      per batch, flipping one of ``cfg.c_dim`` attributes at a time.
    * ``"AttGAN"`` / ``"STGAN"``: saves one strip of real and edited images
      per batch, flipping one of ``cfg.c_dim`` attributes at a time.
    * anything else: reads ``A_test_reader`` and ``B_test_reader`` in
      lockstep and saves fake, cyc and input images for both sides.

    Args:
        epoch: epoch number embedded in every file name.
        cfg: config object (``output``, ``model_net``, plus ``c_dim`` and
            ``selected_attrs`` for the attribute-editing models).
        exe: executor used to run ``test_program``.
        place: device passed to ``LoDTensor.set`` for label tensors.
        test_program: program that computes the generator outputs.
        g_trainer: object exposing the fetch targets (``fake_B``, ...).
        A_test_reader: callable returning an iterator of test batches.
        B_test_reader: second-domain reader (only for the last branch).
        A_id2name: id -> file name mapping for A images.
        B_id2name: id -> file name mapping for B images (last branch).
    """
    out_path = os.path.join(cfg.output, 'test')
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    if cfg.model_net == "Pix2pix":
        _save_pix2pix_results(epoch, exe, test_program, g_trainer,
                              A_test_reader, A_id2name, out_path)
    elif cfg.model_net == "SPADE":
        _save_spade_results(epoch, exe, test_program, g_trainer,
                            A_test_reader, A_id2name, out_path)
    elif cfg.model_net == "StarGAN":
        _save_stargan_results(epoch, cfg, exe, place, test_program, g_trainer,
                              A_test_reader, out_path)
    elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
        _save_attgan_results(epoch, cfg, exe, place, test_program, g_trainer,
                             A_test_reader, out_path)
    else:
        _save_cycle_results(epoch, exe, test_program, g_trainer,
                            A_test_reader, B_test_reader, A_id2name,
                            B_id2name, out_path)
L
lvmengsi 已提交
336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357


class ImagePool(object):
    """Fixed-size buffer of previously seen images with random replay."""

    def __init__(self, pool_size=50):
        """Create an empty pool holding at most ``pool_size`` images.

        A ``pool_size`` of 0 or less disables the pool: ``pool_image`` then
        returns its input unchanged.
        """
        self.pool = []
        self.count = 0  # number of stored images (always len(self.pool))
        self.pool_size = pool_size

    def pool_image(self, image):
        """Return ``image`` or, sometimes, a previously stored image.

        While the pool is not full, ``image`` is stored and returned. Once
        the pool is full, if ``np.random.rand() > 0.5`` a uniformly chosen
        stored image is returned and ``image`` takes its slot. Otherwise
        ``image`` is returned and the pool is left unchanged.
        """
        if self.pool_size <= 0:
            return image
        if self.count < self.pool_size:
            self.pool.append(image)
            self.count += 1
            return image
        if np.random.rand() > 0.5:
            # randint's upper bound is exclusive; pool_size (not
            # pool_size - 1) lets every slot, including the last, be picked.
            random_id = np.random.randint(0, self.pool_size)
            temp = self.pool[random_id]
            self.pool[random_id] = image
            return temp
        return image
L
lvmengsi 已提交
358 359 360


def check_attribute_conflict(label_batch, attr, attrs):
    """Clear attributes that conflict with ``attr`` in every label of a batch.

    Based on https://github.com/LynnHo/AttGAN-Tensorflow. For each label
    whose ``attr`` entry is non-zero, the mutually exclusive attributes
    that appear in ``attrs`` are set to 0:

    * ``Bald`` / ``Receding_Hairline`` clear ``Bangs``;
    * ``Bangs`` clears ``Bald`` and ``Receding_Hairline``;
    * a hair colour clears the other hair colours;
    * ``Straight_Hair`` / ``Wavy_Hair`` clear each other.

    Args:
        label_batch: iterable of mutable per-sample label vectors, indexed
            in the same order as ``attrs``; modified in place.
        attr: the attribute that was just edited; must be in ``attrs``.
        attrs: list of attribute names defining the label layout.

    Returns:
        ``label_batch`` itself, after in-place modification.

    Raises:
        ValueError: if ``attr`` is not in ``attrs`` (from ``list.index``).
    """
    hair_colors = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']
    hair_styles = ['Straight_Hair', 'Wavy_Hair']

    def _set(label, value, name):
        if name in attrs:
            label[attrs.index(name)] = value

    attr_id = attrs.index(attr)
    for label in label_batch:
        # Only resolve conflicts when the edited attribute is switched on.
        # The test is on the label value; attrs[attr_id] is a name string
        # and comparing it with 0 would always be true.
        if label[attr_id] == 0:
            continue
        if attr in ['Bald', 'Receding_Hairline']:
            _set(label, 0, 'Bangs')
        elif attr == 'Bangs':
            _set(label, 0, 'Bald')
            _set(label, 0, 'Receding_Hairline')
        elif attr in hair_colors:
            for a in hair_colors:
                if a != attr:
                    _set(label, 0, a)
        elif attr in hair_styles:
            for a in hair_styles:
                if a != attr:
                    _set(label, 0, a)
    return label_batch
L
lvmengsi 已提交
384 385


L
lvmengsi 已提交
386
def save_batch_image(img):
    """Convert an NCHW batch to channels-last layout.

    Returns an HWC array when the batch holds a single image and an NHWC
    array otherwise. Only the batch axis is removed for a single image, so
    size-1 channel or spatial axes are kept. Squeezing every singleton
    axis, as ``np.squeeze`` does, would make the transpose fail, e.g. for
    single-channel output.

    Args:
        img: array-like of shape (N, C, H, W).

    Returns:
        ``(H, W, C)`` array if ``N == 1``, else ``(N, H, W, C)``.
    """
    batch = np.asarray(img)
    if len(batch) == 1:
        return batch[0].transpose([1, 2, 0])
    return batch.transpose([0, 2, 3, 1])


L
lvmengsi 已提交
395
def check_gpu(use_gpu):
    """Exit with an error if GPU use is requested on CPU-only PaddlePaddle.

    When ``use_gpu`` is true and ``fluid.is_compiled_with_cuda()`` returns
    false, an explanatory message is printed and ``sys.exit(1)`` is called.
    If the CUDA query itself raises, a warning is printed and execution
    continues.

    Args:
        use_gpu: whether the configuration asks to run on GPU.
    """
    err = ("Config use_gpu cannot be set as true while you are "
           "using paddlepaddle cpu version ! \nPlease try: \n"
           "\t1. Install paddlepaddle-gpu to run model on GPU \n"
           "\t2. Set use_gpu as false in config file to run "
           "model on CPU")

    if not use_gpu:
        return

    try:
        compiled_with_cuda = fluid.is_compiled_with_cuda()
    except Exception as e:
        # Best effort: keep running if this Paddle build cannot answer the
        # query, but report it instead of silently ignoring the failure.
        print("Warning: could not check whether PaddlePaddle was compiled "
              "with CUDA: {}".format(e))
        return

    if not compiled_with_cuda:
        print(err)
        sys.exit(1)
L
lvmengsi 已提交
412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427


def check_version():
    """Exit with an error unless PaddlePaddle 1.6.0 or newer is installed.

    Calls ``fluid.require_version('1.6.0')``. If it raises, the
    explanatory message and the underlying exception are printed and the
    process exits with status 1.
    """
    # Parenthesised literal instead of a trailing backslash continuation.
    err = ("PaddlePaddle version 1.6 or higher is required, "
           "or a suitable develop version is satisfied as well. \n"
           "Please make sure the version is good with your code.")

    try:
        fluid.require_version('1.6.0')
    except Exception as e:
        print(err)
        # Show why the check failed instead of discarding the exception.
        print(e)
        sys.exit(1)
C
ceci3 已提交
428 429 430 431 432 433 434 435 436

def get_device_num(args):
    """Return how many devices to run on.

    With ``args.use_gpu`` this is the number of non-empty comma-separated
    ids in ``CUDA_VISIBLE_DEVICES``. If the variable is unset or blank, the
    count is 1. Otherwise it is ``int(CPU_NUM)``, defaulting to 1.

    Args:
        args: config object exposing a boolean ``use_gpu``.

    Returns:
        The device count as an int.
    """
    if args.use_gpu:
        # Default to a string: an int default would crash on ``.split``.
        gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        # Ignore empty entries, e.g. the one left by a trailing comma.
        gpu_ids = [g for g in gpus.split(',') if g.strip()]
        return len(gpu_ids) if gpu_ids else 1
    return int(os.environ.get("CPU_NUM", 1))