提交 b969c14d 编写于 作者: Q qingqing01

Add CycleGAN

上级 6d9e77b9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from layers import ConvBN, DeConvBN
import paddle.fluid as fluid
from model import Model, Loss
class ResnetBlock(fluid.dygraph.Layer):
def __init__(self, dim, dropout=False):
super(ResnetBlock, self).__init__()
self.dropout = dropout
self.conv0 = ConvBN(dim, dim, 3, 1)
self.conv1 = ConvBN(dim, dim, 3, 1, act=None)
def forward(self, inputs):
out_res = fluid.layers.pad2d(inputs, [1, 1, 1, 1], mode="reflect")
out_res = self.conv0(out_res)
if self.dropout:
out_res = fluid.layers.dropout(out_res, dropout_prob=0.5)
out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect")
out_res = self.conv1(out_res)
return out_res + inputs
class ResnetGenerator(fluid.dygraph.Layer):
def __init__(self, input_channel, n_blocks=9, dropout=False):
super(ResnetGenerator, self).__init__()
self.conv0 = ConvBN(input_channel, 32, 7, 1)
self.conv1 = ConvBN(32, 64, 3, 2, padding=1)
self.conv2 = ConvBN(64, 128, 3, 2, padding=1)
dim = 128
self.resnet_blocks = []
for i in range(n_blocks):
block = self.add_sublayer("generator_%d" % (i + 1),
ResnetBlock(dim, dropout))
self.resnet_blocks.append(block)
self.deconv0 = DeConvBN(
dim, 32 * 2, 3, 2, padding=[1, 1], outpadding=[0, 1, 0, 1])
self.deconv1 = DeConvBN(
32 * 2, 32, 3, 2, padding=[1, 1], outpadding=[0, 1, 0, 1])
self.conv3 = ConvBN(
32, input_channel, 7, 1, norm=False, act=False, use_bias=True)
def forward(self, inputs):
pad_input = fluid.layers.pad2d(inputs, [3, 3, 3, 3], mode="reflect")
y = self.conv0(pad_input)
y = self.conv1(y)
y = self.conv2(y)
for resnet_block in self.resnet_blocks:
y = resnet_block(y)
y = self.deconv0(y)
y = self.deconv1(y)
y = fluid.layers.pad2d(y, [3, 3, 3, 3], mode="reflect")
y = self.conv3(y)
y = fluid.layers.tanh(y)
return y
class NLayerDiscriminator(fluid.dygraph.Layer):
def __init__(self, input_channel, d_dims=64, d_nlayers=3):
super(NLayerDiscriminator, self).__init__()
self.conv0 = ConvBN(
input_channel,
d_dims,
4,
2,
1,
norm=False,
use_bias=True,
relufactor=0.2)
nf_mult, nf_mult_prev = 1, 1
self.conv_layers = []
for n in range(1, d_nlayers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
conv = self.add_sublayer(
'discriminator_%d' % (n),
ConvBN(
d_dims * nf_mult_prev,
d_dims * nf_mult,
4,
2,
1,
relufactor=0.2))
self.conv_layers.append(conv)
nf_mult_prev = nf_mult
nf_mult = min(2**d_nlayers, 8)
self.conv4 = ConvBN(
d_dims * nf_mult_prev, d_dims * nf_mult, 4, 1, 1, relufactor=0.2)
self.conv5 = ConvBN(
d_dims * nf_mult,
1,
4,
1,
1,
norm=False,
act=None,
use_bias=True,
relufactor=0.2)
def forward(self, inputs):
y = self.conv0(inputs)
for conv in self.conv_layers:
y = conv(y)
y = self.conv4(y)
y = self.conv5(y)
return y
class Generator(Model):
def __init__(self, input_channel=3):
super(Generator, self).__init__()
self.g = ResnetGenerator(input_channel)
def forward(self, input):
fake = self.g(input)
return fake
class GeneratorCombine(Model):
def __init__(self, g_AB=None, g_BA=None, d_A=None, d_B=None,
is_train=True):
super(GeneratorCombine, self).__init__()
self.g_AB = g_AB
self.g_BA = g_BA
self.is_train = is_train
if self.is_train:
self.d_A = d_A
self.d_B = d_B
def forward(self, input_A, input_B):
# Translate images to the other domain
fake_B = self.g_AB(input_A)
fake_A = self.g_BA(input_B)
# Translate images back to original domain
cyc_A = self.g_BA(fake_B)
cyc_B = self.g_AB(fake_A)
if not self.is_train:
return fake_A, fake_B, cyc_A, cyc_B
# Identity mapping of images
idt_A = self.g_AB(input_B)
idt_B = self.g_BA(input_A)
# Discriminators determines validity of translated images
# d_A(g_AB(A))
valid_A = self.d_A.d(fake_B)
# d_B(g_BA(A))
valid_B = self.d_B.d(fake_A)
return input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B
class GLoss(Loss):
def __init__(self, lambda_A=10., lambda_B=10., lambda_identity=0.5):
super(GLoss, self).__init__()
self.lambda_A = lambda_A
self.lambda_B = lambda_B
self.lambda_identity = lambda_identity
def forward(self, outputs, labels=None):
input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B = outputs
def mse(a, b):
return fluid.layers.reduce_mean(fluid.layers.square(a - b))
def mae(a, b): # L1Loss
return fluid.layers.reduce_mean(fluid.layers.abs(a - b))
g_A_loss = mse(valid_A, 1.)
g_B_loss = mse(valid_B, 1.)
g_loss = g_A_loss + g_B_loss
cyc_A_loss = mae(input_A, cyc_A) * self.lambda_A
cyc_B_loss = mae(input_B, cyc_B) * self.lambda_B
cyc_loss = cyc_A_loss + cyc_B_loss
idt_loss_A = mae(input_B, idt_A) * (self.lambda_B *
self.lambda_identity)
idt_loss_B = mae(input_A, idt_B) * (self.lambda_A *
self.lambda_identity)
idt_loss = idt_loss_A + idt_loss_B
loss = cyc_loss + g_loss + idt_loss
return loss
class Discriminator(Model):
def __init__(self, input_channel=3):
super(Discriminator, self).__init__()
self.d = NLayerDiscriminator(input_channel)
def forward(self, real, fake):
pred_real = self.d(real)
pred_fake = self.d(fake)
return pred_real, pred_fake
class DLoss(Loss):
def __init__(self):
super(DLoss, self).__init__()
def forward(self, inputs, labels=None):
pred_real, pred_fake = inputs
loss = fluid.layers.square(pred_fake) + fluid.layers.square(pred_real -
1.)
loss = fluid.layers.reduce_mean(loss / 2.0)
return loss
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import numpy as np
from PIL import Image, ImageOps
DATASET = "cityscapes"
A_LIST_FILE = "./data/" + DATASET + "/trainA.txt"
B_LIST_FILE = "./data/" + DATASET + "/trainB.txt"
A_TEST_LIST_FILE = "./data/" + DATASET + "/testA.txt"
B_TEST_LIST_FILE = "./data/" + DATASET + "/testB.txt"
IMAGES_ROOT = "./data/" + DATASET + "/"
import paddle.fluid as fluid
class Cityscapes(fluid.io.Dataset):
def __init__(self, root_path, file_path, mode='train', return_name=False):
self.root_path = root_path
self.file_path = file_path
self.mode = mode
self.return_name = return_name
self.images = [root_path + l for l in open(file_path, 'r').readlines()]
def _train(self, image):
## Resize
image = image.resize((286, 286), Image.BICUBIC)
## RandomCrop
i = np.random.randint(0, 30)
j = np.random.randint(0, 30)
image = image.crop((i, j, i + 256, j + 256))
# RandomHorizontalFlip
if np.random.rand() > 0.5:
image = ImageOps.mirror(image)
return image
def __getitem__(self, idx):
f = self.images[idx].strip("\n\r\t ")
image = Image.open(f)
if self.mode == 'train':
image = self._train(image)
else:
image = image.resize((256, 256), Image.BICUBIC)
# ToTensor
image = np.array(image).transpose([2, 0, 1]).astype('float32')
image = image / 255.0
# Normalize, mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]
image = (image - 0.5) / 0.5
if self.return_name:
return [image], os.path.basename(f)
else:
return [image]
def __len__(self):
return len(self.images)
def DataA(root=IMAGES_ROOT, fpath=A_LIST_FILE):
"""
Reader of images with A style for training.
"""
return Cityscapes(root, fpath)
def DataB(root=IMAGES_ROOT, fpath=B_LIST_FILE):
"""
Reader of images with B style for training.
"""
return Cityscapes(root, fpath)
def TestDataA(root=IMAGES_ROOT, fpath=A_TEST_LIST_FILE):
"""
Reader of images with A style for training.
"""
return Cityscapes(root, fpath, mode='test', return_name=True)
def TestDataB(root=IMAGES_ROOT, fpath=B_TEST_LIST_FILE):
"""
Reader of images with B style for training.
"""
return Cityscapes(root, fpath, mode='test', return_name=True)
class ImagePool(object):
def __init__(self, pool_size=50):
self.pool = []
self.count = 0
self.pool_size = pool_size
def get(self, image):
if self.count < self.pool_size:
self.pool.append(image)
self.count += 1
return image
else:
p = random.random()
if p > 0.5:
random_id = random.randint(0, self.pool_size - 1)
temp = self.pool[random_id]
self.pool[random_id] = image
return temp
else:
return image
if __name__ == '__main__':
place = fluid.CUDAPlace(0)
#fluid.enable_dygraph(place)
dataset = DataA(shuffle=False)
a_loader = fluid.io.DataLoader(
dataset,
feed_list=[
fluid.data(
name='im', shape=[
None,
2,
2,
], dtype='float32')
],
places=place,
return_list=False,
batch_size=2)
for data in a_loader:
print(data)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import glob
import numpy as np
import argparse
from PIL import Image
from scipy.misc import imsave
import paddle.fluid as fluid
from model import Model, Input, set_device
from cyclegan import Generator, GeneratorCombine
def main():
place = set_device(FLAGS.device)
fluid.enable_dygraph(place) if FLAGS.dynamic else None
# Generators
g_AB = Generator()
g_BA = Generator()
g = GeneratorCombine(g_AB, g_BA, is_train=False)
im_shape = [-1, 3, 256, 256]
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
g.prepare(inputs=[input_A, input_B])
g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
out_path = FLAGS.output + "/single"
if not os.path.exists(out_path):
os.makedirs(out_path)
for f in glob.glob(FLAGS.input):
image_name = os.path.basename(f)
image = Image.open(f).convert('RGB')
image = image.resize((256, 256), Image.BICUBIC)
image = np.array(image) / 127.5 - 1
image = image[:, :, 0:3].astype("float32")
data = image.transpose([2, 0, 1])[np.newaxis, :]
if FLAGS.input_style == "A":
_, fake, _, _ = g.test([data, data])
if FLAGS.input_style == "B":
fake, _, _, _ = g.test([data, data])
fake = np.squeeze(fake[0]).transpose([1, 2, 0])
opath = "{}/fake{}{}".format(out_path, FLAGS.input_style, image_name)
imsave(opath, ((fake + 1) * 127.5).astype(np.uint8))
print("transfer {} to {}".format(f, opath))
if __name__ == "__main__":
parser = argparse.ArgumentParser("CycleGAN inference")
parser.add_argument(
"-d", "--dynamic", action='store_false', help="Enable dygraph mode")
parser.add_argument(
"-p",
"--device",
type=str,
default='gpu',
help="device to use, gpu or cpu")
parser.add_argument(
"-i",
"--input",
type=str,
default='./image/testA/123_A.jpg',
help="input image")
parser.add_argument(
"-o",
'--output',
type=str,
default='output',
help="The test result to be saved to.")
parser.add_argument(
"-m",
"--init_model",
type=str,
default='checkpoint/194',
help="The init model file of directory.")
parser.add_argument(
"-s", "--input_style", type=str, default='A', help="A or B")
FLAGS = parser.parse_args()
print(FLAGS)
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, BatchNorm
# cudnn is not better when batch size is 1.
use_cudnn = False
import numpy as np
class ConvBN(fluid.dygraph.Layer):
"""docstring for Conv2D"""
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
padding=0,
stddev=0.02,
norm=True,
is_test=False,
act='leaky_relu',
relufactor=0.0,
use_bias=False):
super(ConvBN, self).__init__()
pattr = fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=stddev))
self.conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
use_cudnn=use_cudnn,
param_attr=pattr,
bias_attr=use_bias)
if norm:
self.bn = BatchNorm(
num_filters,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(1.0,
0.02)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0)),
is_test=False,
trainable_statistics=True)
self.relufactor = relufactor
self.norm = norm
self.act = act
def forward(self, inputs):
conv = self.conv(inputs)
if self.norm:
conv = self.bn(conv)
if self.act == 'leaky_relu':
conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor)
elif self.act == 'relu':
conv = fluid.layers.relu(conv)
else:
conv = conv
return conv
class DeConvBN(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
padding=[0, 0],
outpadding=[0, 0, 0, 0],
stddev=0.02,
act='leaky_relu',
norm=True,
is_test=False,
relufactor=0.0,
use_bias=False):
super(DeConvBN, self).__init__()
pattr = fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=stddev))
self._deconv = Conv2DTranspose(
num_channels,
num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=pattr,
bias_attr=use_bias)
if norm:
self.bn = BatchNorm(
num_filters,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(1.0,
0.02)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0)),
is_test=False,
trainable_statistics=True)
self.outpadding = outpadding
self.relufactor = relufactor
self.use_bias = use_bias
self.norm = norm
self.act = act
def forward(self, inputs):
conv = self._deconv(inputs)
conv = fluid.layers.pad2d(
conv, paddings=self.outpadding, mode='constant', pad_value=0.0)
if self.norm:
conv = self.bn(conv)
if self.act == 'leaky_relu':
conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor)
elif self.act == 'relu':
conv = fluid.layers.relu(conv)
else:
conv = conv
return conv
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np
from scipy.misc import imsave
import paddle.fluid as fluid
from model import Model, Input, set_device
from cyclegan import Generator, GeneratorCombine
import data as data
def main():
place = set_device(FLAGS.device)
fluid.enable_dygraph(place) if FLAGS.dynamic else None
# Generators
g_AB = Generator()
g_BA = Generator()
g = GeneratorCombine(g_AB, g_BA, is_train=False)
im_shape = [-1, 3, 256, 256]
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
g.prepare(inputs=[input_A, input_B])
g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
if not os.path.exists(FLAGS.output):
os.makedirs(FLAGS.output)
test_data_A = data.TestDataA()
test_data_B = data.TestDataB()
for i in range(len(test_data_A)):
data_A, A_name = test_data_A[i]
data_B, B_name = test_data_B[i]
data_A = np.array(data_A).astype("float32")
data_B = np.array(data_B).astype("float32")
fake_A, fake_B, cyc_A, cyc_B = g.test([data_A, data_B])
datas = [fake_A, fake_B, cyc_A, cyc_B, data_A, data_B]
odatas = []
for o in datas:
d = np.squeeze(o[0]).transpose([1, 2, 0])
im = ((d + 1) * 127.5).astype(np.uint8)
odatas.append(im)
imsave(FLAGS.output + "/fakeA_" + B_name, odatas[0])
imsave(FLAGS.output + "/fakeB_" + A_name, odatas[1])
imsave(FLAGS.output + "/cycA_" + A_name, odatas[2])
imsave(FLAGS.output + "/cycB_" + B_name, odatas[3])
imsave(FLAGS.output + "/inputA_" + A_name, odatas[4])
imsave(FLAGS.output + "/inputB_" + B_name, odatas[5])
if __name__ == "__main__":
parser = argparse.ArgumentParser("CycleGAN test")
parser.add_argument(
"-d", "--dynamic", action='store_false', help="Enable dygraph mode")
parser.add_argument(
"-p",
"--device",
type=str,
default='gpu',
help="device to use, gpu or cpu")
parser.add_argument(
"-b", "--batch_size", default=1, type=int, help="batch size")
parser.add_argument(
"-o",
'--output',
type=str,
default='output/eval',
help="The test result to be saved to.")
parser.add_argument(
"-m",
"--init_model",
type=str,
default='checkpoint/199',
help="The init model file of directory.")
FLAGS = parser.parse_args()
print(FLAGS)
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
import argparse
import contextlib
import time
import paddle
import paddle.fluid as fluid
from model import Model, Input, set_device
import reader as reader
import data as data
from cyclegan import Generator, Discriminator, GeneratorCombine, GLoss, DLoss
step_per_epoch = 2974
def opt(parameters):
lr_base = 0.0002
bounds = [100, 120, 140, 160, 180]
lr = [1., 0.8, 0.6, 0.4, 0.2, 0.1]
bounds = [i * step_per_epoch for i in bounds]
lr = [i * lr_base for i in lr]
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bounds, values=lr),
parameter_list=parameters,
beta1=0.5)
return optimizer
def main():
place = set_device(FLAGS.device)
fluid.enable_dygraph(place) if FLAGS.dynamic else None
# Generators
g_AB = Generator()
g_BA = Generator()
# Discriminators
d_A = Discriminator()
d_B = Discriminator()
g = GeneratorCombine(g_AB, g_BA, d_A, d_B)
da_params = d_A.parameters()
db_params = d_B.parameters()
g_params = g_AB.parameters() + g_BA.parameters()
da_optimizer = opt(da_params)
db_optimizer = opt(db_params)
g_optimizer = opt(g_params)
im_shape = [None, 3, 256, 256]
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
fake_A = Input(im_shape, 'float32', 'fake_A')
fake_B = Input(im_shape, 'float32', 'fake_B')
g_AB.prepare(inputs=[input_A])
g_BA.prepare(inputs=[input_B])
g.prepare(g_optimizer, GLoss(), inputs=[input_A, input_B])
d_A.prepare(da_optimizer, DLoss(), inputs=[input_B, fake_B])
d_B.prepare(db_optimizer, DLoss(), inputs=[input_A, fake_A])
loader_A = fluid.io.DataLoader(
data.DataA(),
feed_list=[x.forward() for x in [input_A]]
if not FLAGS.dynamic else None,
places=place,
shuffle=True,
return_list=True,
use_buffer_reader=True,
batch_size=FLAGS.batch_size)
loader_B = fluid.io.DataLoader(
data.DataB(),
feed_list=[x.forward() for x in [input_B]]
if not FLAGS.dynamic else None,
places=place,
shuffle=True,
return_list=True,
use_buffer_reader=True,
batch_size=FLAGS.batch_size)
A_pool = data.ImagePool()
B_pool = data.ImagePool()
for epoch in range(FLAGS.epoch):
for i, (data_A, data_B) in enumerate(zip(loader_A, loader_B)):
data_A = data_A[0][0] if not FLAGS.dynamic else data_A[0]
data_B = data_B[0][0] if not FLAGS.dynamic else data_B[0]
start = time.time()
fake_B = g_AB.test(data_A)[0]
fake_A = g_BA.test(data_B)[0]
g_loss = g.train([data_A, data_B])[0]
fake_pb = B_pool.get(fake_B)
da_loss = d_A.train([data_B, fake_pb])[0]
fake_pa = A_pool.get(fake_A)
db_loss = d_B.train([data_A, fake_pa])[0]
t = time.time() - start
if i % 20 == 0:
print("epoch: {} | step: {:3d} | g_loss: {:.4f} | " +
"da_loss: {:.4f} | db_loss: {:.4f} | s/step {:.4f}".
format(epoch, i, g_loss[0], da_loss[0], db_loss[0], t))
g.save('{}/{}'.format(FLAGS.checkpoint_path, epoch))
if __name__ == "__main__":
parser = argparse.ArgumentParser("CycleGAN Training on Cityscapes")
parser.add_argument(
"-d", "--dynamic", action='store_false', help="Enable dygraph mode")
parser.add_argument(
"--device", type=str, default='gpu', help="device to use, gpu or cpu")
parser.add_argument(
"-e", "--epoch", default=200, type=int, help="Epoch number")
parser.add_argument(
"-b", "--batch_size", default=1, type=int, help="batch size")
parser.add_argument(
"-o",
"--checkpoint_path",
type=str,
default='checkpoint',
help="path to save checkpoint")
parser.add_argument(
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
FLAGS = parser.parse_args()
main()
......@@ -112,9 +112,9 @@ class Loss(object):
def forward(self, outputs, labels):
raise NotImplementedError()
def __call__(self, outputs, labels):
def __call__(self, outputs, labels=None):
labels = to_list(labels)
if in_dygraph_mode():
if in_dygraph_mode() and labels:
labels = [to_variable(l) for l in labels]
losses = to_list(self.forward(to_list(outputs), labels))
if self.average:
......@@ -870,8 +870,6 @@ class Model(fluid.dygraph.Layer):
if not isinstance(inputs, (list, dict, Input)):
raise TypeError(
"'inputs' must be list or dict in static graph mode")
if loss_function and not isinstance(labels, (list, Input)):
raise TypeError("'labels' must be list in static graph mode")
metrics = metrics or []
for metric in to_list(metrics):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册