From 345b30e2b20c1503d9c3323107591a0d2d11e35b Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Thu, 10 Sep 2020 07:52:10 +0000 Subject: [PATCH] Update CycleGAN --- examples/cyclegan/cyclegan.py | 24 +++++++++------------ examples/cyclegan/infer.py | 20 ++++++++++++------ examples/cyclegan/test.py | 19 +++++++++++------ examples/cyclegan/train.py | 40 ++++++++++++++++++----------------- 4 files changed, 56 insertions(+), 47 deletions(-) diff --git a/examples/cyclegan/cyclegan.py b/examples/cyclegan/cyclegan.py index f6a2dae..5ace88b 100644 --- a/examples/cyclegan/cyclegan.py +++ b/examples/cyclegan/cyclegan.py @@ -18,9 +18,8 @@ from __future__ import print_function import numpy as np +import paddle import paddle.fluid as fluid -from paddle.incubate.hapi.model import Model -from paddle.incubate.hapi.loss import Loss from layers import ConvBN, DeConvBN @@ -133,7 +132,7 @@ class NLayerDiscriminator(fluid.dygraph.Layer): return y -class Generator(Model): +class Generator(paddle.nn.Layer): def __init__(self, input_channel=3): super(Generator, self).__init__() self.g = ResnetGenerator(input_channel) @@ -143,7 +142,7 @@ class Generator(Model): return fake -class GeneratorCombine(Model): +class GeneratorCombine(paddle.nn.Layer): def __init__(self, g_AB=None, g_BA=None, d_A=None, d_B=None, is_train=True): super(GeneratorCombine, self).__init__() @@ -177,16 +176,15 @@ class GeneratorCombine(Model): return input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B -class GLoss(Loss): +class GLoss(paddle.nn.Layer): def __init__(self, lambda_A=10., lambda_B=10., lambda_identity=0.5): super(GLoss, self).__init__() self.lambda_A = lambda_A self.lambda_B = lambda_B self.lambda_identity = lambda_identity - def forward(self, outputs, labels=None): - input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B = outputs - + def forward(self, input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, + idt_B, valid_A, valid_B): def mse(a, b): return fluid.layers.reduce_mean(fluid.layers.square(a - b)) @@ -211,7 +209,7 @@ class GLoss(Loss): return loss -class Discriminator(Model): +class Discriminator(paddle.nn.Layer): def __init__(self, input_channel=3): super(Discriminator, self).__init__() self.d = NLayerDiscriminator(input_channel) @@ -222,13 +220,11 @@ class Discriminator(Model): return pred_real, pred_fake -class DLoss(Loss): +class DLoss(paddle.nn.Layer): def __init__(self): super(DLoss, self).__init__() - def forward(self, inputs, labels=None): - pred_real, pred_fake = inputs - loss = fluid.layers.square(pred_fake) + fluid.layers.square(pred_real - - 1.) + def forward(self, real, fake): + loss = fluid.layers.square(fake) + fluid.layers.square(real - 1.) loss = fluid.layers.reduce_mean(loss / 2.0) return loss diff --git a/examples/cyclegan/infer.py b/examples/cyclegan/infer.py index bbefaf6..9355243 100644 --- a/examples/cyclegan/infer.py +++ b/examples/cyclegan/infer.py @@ -24,26 +24,32 @@ import argparse from PIL import Image from scipy.misc import imsave +import paddle import paddle.fluid as fluid -from paddle.incubate.hapi.model import Model, Input, set_device +from paddle.static import InputSpec as Input from check import check_gpu, check_version from cyclegan import Generator, GeneratorCombine def main(): - place = set_device(FLAGS.device) + place = paddle.set_device(FLAGS.device) fluid.enable_dygraph(place) if FLAGS.dynamic else None + im_shape = [-1, 3, 256, 256] + input_A = Input(im_shape, 'float32', 'input_A') + input_B = Input(im_shape, 'float32', 'input_B') + # Generators g_AB = Generator() g_BA = Generator() - g = GeneratorCombine(g_AB, g_BA, is_train=False) - im_shape = [-1, 3, 256, 256] - input_A = Input(im_shape, 'float32', 'input_A') - input_B = Input(im_shape, 'float32', 'input_B') - g.prepare(inputs=[input_A, input_B], device=FLAGS.device) + g = paddle.Model( + GeneratorCombine( + g_AB, g_BA, is_train=False), + inputs=[input_A, input_B]) + g.prepare() + g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True) out_path = FLAGS.output + "/single" diff --git a/examples/cyclegan/test.py b/examples/cyclegan/test.py index ba7d5c5..7b0059b 100644 --- a/examples/cyclegan/test.py +++ b/examples/cyclegan/test.py @@ -21,8 +21,9 @@ import argparse import numpy as np from scipy.misc import imsave +import paddle import paddle.fluid as fluid -from paddle.incubate.hapi.model import Model, Input, set_device +from paddle.static import InputSpec as Input from check import check_gpu, check_version from cyclegan import Generator, GeneratorCombine @@ -30,18 +31,22 @@ import data as data def main(): - place = set_device(FLAGS.device) + place = paddle.set_device(FLAGS.device) fluid.enable_dygraph(place) if FLAGS.dynamic else None + im_shape = [-1, 3, 256, 256] + input_A = Input(im_shape, 'float32', 'input_A') + input_B = Input(im_shape, 'float32', 'input_B') + # Generators g_AB = Generator() g_BA = Generator() - g = GeneratorCombine(g_AB, g_BA, is_train=False) + g = paddle.Model( + GeneratorCombine( + g_AB, g_BA, is_train=False), + inputs=[input_A, input_B]) - im_shape = [-1, 3, 256, 256] - input_A = Input(im_shape, 'float32', 'input_A') - input_B = Input(im_shape, 'float32', 'input_B') - g.prepare(inputs=[input_A, input_B], device=FLAGS.device) + g.prepare() g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True) if not os.path.exists(FLAGS.output): diff --git a/examples/cyclegan/train.py b/examples/cyclegan/train.py index de9ed63..656616b 100644 --- a/examples/cyclegan/train.py +++ b/examples/cyclegan/train.py @@ -24,7 +24,7 @@ import time import paddle import paddle.fluid as fluid -from paddle.incubate.hapi.model import Model, Input, set_device +from paddle.static import InputSpec as Input from check import check_gpu, check_version from cyclegan import Generator, Discriminator, GeneratorCombine, GLoss, DLoss @@ -48,18 +48,29 @@ def opt(parameters): def main(): - place = set_device(FLAGS.device) + place = paddle.set_device(FLAGS.device) fluid.enable_dygraph(place) if FLAGS.dynamic else None + im_shape = [None, 3, 256, 256] + input_A = Input(im_shape, 'float32', 'input_A') + input_B = Input(im_shape, 'float32', 'input_B') + fake_A = Input(im_shape, 'float32', 'fake_A') + fake_B = Input(im_shape, 'float32', 'fake_B') + # Generators g_AB = Generator() g_BA = Generator() - - # Discriminators d_A = Discriminator() d_B = Discriminator() - g = GeneratorCombine(g_AB, g_BA, d_A, d_B) + g = paddle.Model( + GeneratorCombine(g_AB, g_BA, d_A, d_B), inputs=[input_A, input_B]) + g_AB = paddle.Model(g_AB, [input_A]) + g_BA = paddle.Model(g_BA, [input_B]) + + # Discriminators + d_A = paddle.Model(d_A, [input_B, fake_B]) + d_B = paddle.Model(d_B, [input_A, fake_A]) da_params = d_A.parameters() db_params = d_B.parameters() @@ -69,21 +80,12 @@ def main(): db_optimizer = opt(db_params) g_optimizer = opt(g_params) - im_shape = [None, 3, 256, 256] - input_A = Input(im_shape, 'float32', 'input_A') - input_B = Input(im_shape, 'float32', 'input_B') - fake_A = Input(im_shape, 'float32', 'fake_A') - fake_B = Input(im_shape, 'float32', 'fake_B') - - g_AB.prepare(inputs=[input_A], device=FLAGS.device) - g_BA.prepare(inputs=[input_B], device=FLAGS.device) + g_AB.prepare() + g_BA.prepare() - g.prepare( - g_optimizer, GLoss(), inputs=[input_A, input_B], device=FLAGS.device) - d_A.prepare( - da_optimizer, DLoss(), inputs=[input_B, fake_B], device=FLAGS.device) - d_B.prepare( - db_optimizer, DLoss(), inputs=[input_A, fake_A], device=FLAGS.device) + g.prepare(g_optimizer, GLoss()) + d_A.prepare(da_optimizer, DLoss()) + d_B.prepare(db_optimizer, DLoss()) if FLAGS.resume: g.load(FLAGS.resume) -- GitLab