# infer.py (last committed by xiaoting)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
import paddle
import argparse
import functools
import time
import numpy as np
import glob
from PIL import Image
from scipy.misc import imsave
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import data_reader
from utility import add_arguments, print_arguments, ImagePool
from trainer import *
from paddle.fluid.dygraph.base import to_variable
import six
# Command-line interface. ``add_arguments`` (from ``utility``) registers each
# option on ``parser``; the parsed namespace becomes the global ``args`` that
# infer() reads.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)


# yapf: disable
# input:       passed to glob.glob() by infer(), so a wildcard pattern selects
#              several images.
# output:      results are written under <output>/single.
# init_model:  checkpoint directory given to fluid.dygraph.load_persistables.
# input_style: "A" saves the A->B translation, "B" saves the B->A translation.
add_arg('input',             str,   "./image/testA/123_A.jpg",      "input image")
add_arg('output',            str,   "./output_0", "The directory the model and the test result to be saved to.")
add_arg('init_model',        str,   './output_0/checkpoints/0',       "The init model file of directory.")
add_arg('input_style',       str,   "A",        "A or B")
# yapf: enable


def _to_uint8_image(img):
    """Map a generator output in [-1, 1] back to uint8 pixels in [0, 255].

    Values are clipped first: casting an out-of-range float straight to
    uint8 does not saturate (the result is platform-dependent and typically
    wraps around), which would show up as speckled pixels.
    """
    return np.clip((img + 1) * 127.5, 0, 255).astype(np.uint8)


def infer():
    """Translate every image matching ``args.input`` with a trained CycleGAN.

    Reads the module-level ``args`` namespace (assigned in the ``__main__``
    block):

    * ``args.input``       -- glob pattern of input image files.
    * ``args.output``      -- results are written to ``<output>/single``.
    * ``args.init_model``  -- checkpoint directory passed to
      ``fluid.dygraph.load_persistables``.
    * ``args.input_style`` -- ``"A"`` saves the A->B output as
      ``fakeB_<name>``; ``"B"`` saves the B->A output as ``fakeA_<name>``.

    Each input is converted to RGB, resized to 256x256 (bicubic), scaled to
    [-1, 1] and fed through the model.

    Raises:
        ValueError: if ``args.input_style`` is neither ``"A"`` nor ``"B"``
            (such a value previously produced no output at all).
    """
    # Index into the model's output tuple: 0 is fake_A (B->A), 1 is fake_B
    # (A->B), matching the unpacking order of the forward call's results.
    if args.input_style == "A":
        prefix, output_index = "fakeB_", 1
    elif args.input_style == "B":
        prefix, output_index = "fakeA_", 0
    else:
        raise ValueError(
            "input_style must be 'A' or 'B', got %r" % (args.input_style,))

    with fluid.dygraph.guard():
        out_path = os.path.join(args.output, "single")
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        # Cycle_Gan presumably comes from ``from trainer import *``.
        cycle_gan = Cycle_Gan("cycle_gan")
        restore = fluid.dygraph.load_persistables(args.init_model)
        cycle_gan.load_dict(restore)
        cycle_gan.eval()

        paths = glob.glob(args.input)
        if not paths:
            print("no input images match %s" % args.input)

        for path in paths:
            print("read %s" % path)
            image_name = os.path.basename(path)
            # Context manager closes the underlying file; convert() returns a
            # fully loaded copy, so it stays usable after the close.
            with Image.open(path) as raw:
                image = raw.convert('RGB')
            image = image.resize((256, 256), Image.BICUBIC)
            # uint8 [0, 255] -> float32 [-1, 1]; HWC -> NCHW with batch of 1.
            pixels = (np.array(image) / 127.5 - 1).astype("float32")
            data = to_variable(pixels.transpose([2, 0, 1])[np.newaxis, :])

            # The same tensor is fed as both the A and the B input; only the
            # generator output selected by input_style is kept.
            outputs = cycle_gan(data, data, True, False, False)
            fake = np.squeeze(outputs[output_index].numpy()[0])
            fake = fake.transpose([1, 2, 0])  # CHW -> HWC for saving

            # PIL is used instead of scipy.misc.imsave (removed in SciPy 1.2);
            # for a uint8 RGB array the written image is the same.
            Image.fromarray(_to_uint8_image(fake)).save(
                os.path.join(out_path, prefix + image_name))


if __name__ == "__main__":
    # Bound at module scope on purpose: infer() reads the global ``args``.
    # argparse parses sys.argv[1:] by default; passing it explicitly is
    # equivalent and makes the source of the flags obvious.
    args = parser.parse_args(sys.argv[1:])
    print_arguments(args)  # echo the effective configuration
    infer()