infer.py 2.6 KB
Newer Older
X
xiaoting 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
import paddle
import argparse
import functools
import time
import numpy as np
import glob
from PIL import Image
from scipy.misc import imsave
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import data_reader
from utility import add_arguments, print_arguments, ImagePool
from trainer import *
from paddle.fluid.dygraph.base import to_variable
import six
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)


# Command-line options; infer() reads them through the module-level `args`.
# yapf: disable
add_arg('input',             str,   "123_A.jpg",      "Input image path or glob pattern; every matching file is translated.")
add_arg('output',            str,   "./output_0",     "Root directory for results; images are written under <output>/single/<input>/.")
add_arg('init_model',        str,   './G/150',        "Saved model directory to restore the Cycle_Gan parameters from.")
add_arg('input_style',       str,   "A",              "Domain of the input images: A (saves fakeB_*) or B (saves fakeA_*).")
# yapf: enable
def _load_image_chw(path, size=256):
    """Load an image file as a (1, 3, size, size) float32 array in [-1, 1].

    The image is converted to RGB, resized with bicubic interpolation, and
    pixel values are mapped from [0, 255] to [-1, 1] via ``x / 127.5 - 1``.
    """
    with Image.open(path) as img:
        img = img.convert('RGB').resize((size, size), Image.BICUBIC)
        hwc = (np.array(img) / 127.5 - 1).astype("float32")
    # HWC -> CHW, then prepend a batch axis of length 1.
    return hwc.transpose([2, 0, 1])[np.newaxis, :]


def _to_uint8_hwc(chw):
    """Convert a (3, H, W) array scaled to [-1, 1] into an (H, W, 3) uint8 image.

    The values are clipped to [0, 255] after rescaling. Out-of-range values
    then saturate instead of wrapping around on the uint8 cast.
    """
    hwc = np.squeeze(chw).transpose([1, 2, 0])
    return np.clip((hwc + 1) * 127.5, 0, 255).astype(np.uint8)


def infer():
    """Translate every image matched by ``args.input`` with a trained CycleGAN.

    Reads the module-level ``args`` (bound in ``__main__``):
      * ``args.init_model``: directory handed to
        ``fluid.dygraph.load_persistables``; the result is loaded into Cycle_Gan.
      * ``args.input``: image path or glob pattern to translate.
      * ``args.input_style``: "A" saves ``fakeB_<name>``, "B" saves
        ``fakeA_<name>``.
      * ``args.output``: results go to ``<output>/single/<input>/``.

    Raises:
        ValueError: if ``args.input_style`` is neither "A" nor "B". Such a
            run would otherwise load the model and write nothing.
    """
    if args.input_style not in ("A", "B"):
        raise ValueError(
            "input_style must be 'A' or 'B', got %r" % (args.input_style,))

    with fluid.dygraph.guard():
        # NOTE(review): the raw input string (possibly a glob pattern) becomes
        # part of the output directory. Plain concatenation is kept on purpose:
        # os.path.join would drop the prefix for an absolute input path.
        out_path = args.output + "/single" + "/" + str(args.input)
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        cycle_gan = Cycle_Gan("cycle_gan")
        restore = fluid.dygraph.load_persistables(args.init_model)
        cycle_gan.load_dict(restore)
        cycle_gan.eval()

        image_paths = glob.glob(args.input)
        if not image_paths:
            print("no image matches %s" % args.input)
        for image_path in image_paths:
            print("read %s" % image_path)
            image_name = os.path.basename(image_path)
            data = to_variable(_load_image_chw(image_path))

            # The same tensor is passed as both inputs. Only the first two
            # outputs (fake_A, fake_B) are used here. The meaning of the three
            # flags is defined by Cycle_Gan in trainer (not visible here).
            outputs = cycle_gan(data, data, True, False, False)
            fake_A, fake_B = outputs[0], outputs[1]

            if args.input_style == "A":
                prefix, fake = "fakeB_", fake_B
            else:
                prefix, fake = "fakeA_", fake_A
            imsave(os.path.join(out_path, prefix + image_name),
                   _to_uint8_hwc(fake.numpy()[0]))


if __name__ == "__main__":
    # infer() reads the module-global `args`, so it must be bound here,
    # at module scope, before infer() is called.
    args = parser.parse_args(sys.argv[1:])
    print_arguments(args)
    infer()