# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import glob
import numpy as np
import argparse

from PIL import Image
from scipy.misc import imsave

import paddle.fluid as fluid
from hapi.model import Model, Input, set_device
from check import check_gpu, check_version
from cyclegan import Generator, GeneratorCombine


def main():
    """Run CycleGAN inference on every image matched by ``FLAGS.input``.

    Builds the two generators and wraps them in ``GeneratorCombine`` for
    inference. Loads weights from ``FLAGS.init_model``. Writes each
    translated image to ``<FLAGS.output>/single/fake<style><basename>``.

    Reads the module-level ``FLAGS`` namespace. That namespace is only
    created by the ``__main__`` argument parser, so this function is meant
    to be run as a script.

    Raises:
        ValueError: if ``FLAGS.input_style`` is not ``"A"`` or ``"B"``.
    """
    # Position, within the outputs of ``g.test_batch``, of the translated
    # image for each input domain. These are the same indices the original
    # tuple unpacking used. The output order itself is defined by
    # GeneratorCombine (cyclegan.py, not shown here).
    output_index = {"A": 1, "B": 0}
    if FLAGS.input_style not in output_index:
        # Without this check an unknown style left ``fake`` unbound
        # (NameError). After the first image it silently reused the
        # previous image's result.
        raise ValueError("input_style must be 'A' or 'B', got {!r}".format(
            FLAGS.input_style))
    fake_index = output_index[FLAGS.input_style]

    place = set_device(FLAGS.device)
    if FLAGS.dynamic:
        fluid.enable_dygraph(place)

    # Generators
    g_AB = Generator()
    g_BA = Generator()
    g = GeneratorCombine(g_AB, g_BA, is_train=False)

    im_shape = [-1, 3, 256, 256]
    input_A = Input(im_shape, 'float32', 'input_A')
    input_B = Input(im_shape, 'float32', 'input_B')
    g.prepare(inputs=[input_A, input_B], device=FLAGS.device)
    g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)

    out_path = os.path.join(FLAGS.output, "single")
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    files = glob.glob(FLAGS.input)
    if not files:
        print("no input images match {}".format(FLAGS.input))

    for f in files:
        image_name = os.path.basename(f)
        # ``with`` closes the underlying file handle. convert() returns a new
        # in-memory image, so the result stays usable after the block.
        with Image.open(f) as img:
            image = img.convert('RGB').resize((256, 256), Image.BICUBIC)
        # Map uint8 [0, 255] to float32 [-1, 1]. The RGB conversion
        # guarantees exactly 3 channels.
        image = (np.array(image) / 127.5 - 1).astype("float32")
        # HWC -> CHW, then prepend a batch axis of size 1.
        data = image.transpose([2, 0, 1])[np.newaxis, :]

        # The same image is fed to both generator inputs; only one output
        # is kept.
        fake = g.test_batch([data, data])[fake_index]
        # Take the first (only) sample and reorder it to HWC for saving.
        # This presumes a (1, C, H, W) output; confirm against
        # GeneratorCombine.
        fake = np.squeeze(fake[0]).transpose([1, 2, 0])
        # Map [-1, 1] back to [0, 255]. Clipping stops out-of-range values
        # from wrapping around when cast to uint8.
        pixels = np.clip((fake + 1) * 127.5, 0, 255).astype(np.uint8)

        opath = os.path.join(out_path,
                             "fake{}{}".format(FLAGS.input_style, image_name))
        # Save with Pillow: scipy.misc.imsave was removed in SciPy 1.2.
        # For uint8 RGB data the two produce the same file.
        Image.fromarray(pixels).save(opath)
        print("transfer {} to {}".format(f, opath))


if __name__ == "__main__":
    parser = argparse.ArgumentParser("CycleGAN inference")
    parser.add_argument(
77
        "-d", "--dynamic", action='store_true', help="Enable dygraph mode")
Q
qingqing01 已提交
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
    parser.add_argument(
        "-p",
        "--device",
        type=str,
        default='gpu',
        help="device to use, gpu or cpu")
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        default='./image/testA/123_A.jpg',
        help="input image")
    parser.add_argument(
        "-o",
        '--output',
        type=str,
        default='output',
        help="The test result to be saved to.")
    parser.add_argument(
        "-m",
        "--init_model",
        type=str,
Q
qingqing01 已提交
100
        default='checkpoint/199',
Q
qingqing01 已提交
101 102 103 104
        help="The init model file of directory.")
    parser.add_argument(
        "-s", "--input_style", type=str, default='A', help="A or B")
    FLAGS = parser.parse_args()
Q
qingqing01 已提交
105
    print(FLAGS)
Q
qingqing01 已提交
106 107
    check_gpu(str.lower(FLAGS.device) == 'gpu')
    check_version()
Q
qingqing01 已提交
108
    main()