Commit d0b9667b authored by qingqing01

Move cyclegan to examples and change train/test/eval to train_batch/test_batch/eval_batch

Parent: d8541eac
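The user-facing change is the rename of the single-batch entry points on hapi.model.Model: train/eval/test become train_batch/eval_batch/test_batch, and the example scripts now pass the target device to prepare(). Below is a minimal, hypothetical sketch of how the renamed API is called, assuming the hapi.model interface shown in the hunks that follow; SimpleGen and the random input are placeholders, not part of this commit.

import numpy as np
import paddle.fluid as fluid
from hapi.model import Model, Input, set_device

class SimpleGen(Model):
    """Placeholder stand-in for the CycleGAN generators in this example."""

    def __init__(self):
        super(SimpleGen, self).__init__()
        self.conv = fluid.dygraph.Conv2D(3, 3, filter_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

place = set_device('cpu')
fluid.enable_dygraph(place)  # omit this call to exercise the static-graph adapter

model = SimpleGen()
model.prepare(inputs=[Input([-1, 3, 256, 256], 'float32', 'image')], device='cpu')

# Formerly model.test(...): run one forward pass on a single batch.
batch = np.random.rand(1, 3, 256, 256).astype('float32')
out = model.test_batch(batch)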
@@ -18,9 +18,10 @@ from __future__ import print_function
 import numpy as np
-from layers import ConvBN, DeConvBN
 import paddle.fluid as fluid
-from model import Model, Loss
+from hapi.model import Model, Loss
+from layers import ConvBN, DeConvBN
 class ResnetBlock(fluid.dygraph.Layer):
...
@@ -20,6 +20,8 @@ import random
 import numpy as np
 from PIL import Image, ImageOps
+import paddle
 DATASET = "cityscapes"
 A_LIST_FILE = "./data/" + DATASET + "/trainA.txt"
 B_LIST_FILE = "./data/" + DATASET + "/trainB.txt"
@@ -27,8 +29,6 @@ A_TEST_LIST_FILE = "./data/" + DATASET + "/testA.txt"
 B_TEST_LIST_FILE = "./data/" + DATASET + "/testB.txt"
 IMAGES_ROOT = "./data/" + DATASET + "/"
-import paddle.fluid as fluid
 class Cityscapes(paddle.io.Dataset):
     def __init__(self, root_path, file_path, mode='train', return_name=False):
...
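The data-pipeline change above is that Cityscapes is based on paddle.io.Dataset, which is why import paddle is added and the module-level import paddle.fluid as fluid is dropped. A minimal sketch of the Dataset contract that class relies on; RandomPairs and the random arrays are placeholders, not part of the commit.

import numpy as np
import paddle

class RandomPairs(paddle.io.Dataset):
    """Placeholder dataset yielding random image-shaped float32 arrays."""

    def __init__(self, length=4):
        self.length = length

    def __getitem__(self, idx):
        # Cityscapes returns preprocessed 3x256x256 images; fake one here.
        return np.random.rand(3, 256, 256).astype('float32')

    def __len__(self):
        return self.length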
@@ -25,9 +25,9 @@ from PIL import Image
 from scipy.misc import imsave
 import paddle.fluid as fluid
-from check import check_gpu, check_version
-from model import Model, Input, set_device
+from hapi.model import Model, Input, set_device
+from check import check_gpu, check_version
 from cyclegan import Generator, GeneratorCombine
@@ -43,7 +43,7 @@ def main():
     im_shape = [-1, 3, 256, 256]
     input_A = Input(im_shape, 'float32', 'input_A')
     input_B = Input(im_shape, 'float32', 'input_B')
-    g.prepare(inputs=[input_A, input_B])
+    g.prepare(inputs=[input_A, input_B], device=FLAGS.device)
     g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
     out_path = FLAGS.output + "/single"
@@ -59,10 +59,10 @@ def main():
         data = image.transpose([2, 0, 1])[np.newaxis, :]
         if FLAGS.input_style == "A":
-            _, fake, _, _ = g.test([data, data])
+            _, fake, _, _ = g.test_batch([data, data])
         if FLAGS.input_style == "B":
-            fake, _, _, _ = g.test([data, data])
+            fake, _, _, _ = g.test_batch([data, data])
         fake = np.squeeze(fake[0]).transpose([1, 2, 0])
@@ -74,7 +74,7 @@ def main():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("CycleGAN inference")
     parser.add_argument(
-        "-d", "--dynamic", action='store_false', help="Enable dygraph mode")
+        "-d", "--dynamic", action='store_true', help="Enable dygraph mode")
     parser.add_argument(
         "-p",
         "--device",
...
@@ -22,9 +22,9 @@ import numpy as np
 from scipy.misc import imsave
 import paddle.fluid as fluid
-from check import check_gpu, check_version
-from model import Model, Input, set_device
+from hapi.model import Model, Input, set_device
+from check import check_gpu, check_version
 from cyclegan import Generator, GeneratorCombine
 import data as data
@@ -41,7 +41,7 @@ def main():
     im_shape = [-1, 3, 256, 256]
     input_A = Input(im_shape, 'float32', 'input_A')
     input_B = Input(im_shape, 'float32', 'input_B')
-    g.prepare(inputs=[input_A, input_B])
+    g.prepare(inputs=[input_A, input_B], device=FLAGS.device)
     g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
     if not os.path.exists(FLAGS.output):
@@ -56,7 +56,7 @@ def main():
         data_A = np.array(data_A).astype("float32")
         data_B = np.array(data_B).astype("float32")
-        fake_A, fake_B, cyc_A, cyc_B = g.test([data_A, data_B])
+        fake_A, fake_B, cyc_A, cyc_B = g.test_batch([data_A, data_B])
         datas = [fake_A, fake_B, cyc_A, cyc_B, data_A, data_B]
         odatas = []
@@ -75,7 +75,7 @@ def main():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("CycleGAN test")
     parser.add_argument(
-        "-d", "--dynamic", action='store_false', help="Enable dygraph mode")
+        "-d", "--dynamic", action='store_true', help="Enable dygraph mode")
     parser.add_argument(
         "-p",
         "--device",
...
@@ -24,12 +24,11 @@ import time
 import paddle
 import paddle.fluid as fluid
-from check import check_gpu, check_version
-from model import Model, Input, set_device
-import data as data
+from hapi.model import Model, Input, set_device
+from check import check_gpu, check_version
 from cyclegan import Generator, Discriminator, GeneratorCombine, GLoss, DLoss
+import data as data
 step_per_epoch = 2974
@@ -76,12 +75,15 @@ def main():
     fake_A = Input(im_shape, 'float32', 'fake_A')
     fake_B = Input(im_shape, 'float32', 'fake_B')
-    g_AB.prepare(inputs=[input_A])
-    g_BA.prepare(inputs=[input_B])
-    g.prepare(g_optimizer, GLoss(), inputs=[input_A, input_B])
-    d_A.prepare(da_optimizer, DLoss(), inputs=[input_B, fake_B])
-    d_B.prepare(db_optimizer, DLoss(), inputs=[input_A, fake_A])
+    g_AB.prepare(inputs=[input_A], device=FLAGS.device)
+    g_BA.prepare(inputs=[input_B], device=FLAGS.device)
+    g.prepare(g_optimizer, GLoss(), inputs=[input_A, input_B],
+              device=FLAGS.device)
+    d_A.prepare(da_optimizer, DLoss(), inputs=[input_B, fake_B],
+                device=FLAGS.device)
+    d_B.prepare(db_optimizer, DLoss(), inputs=[input_A, fake_A],
+                device=FLAGS.device)
     if FLAGS.resume:
         g.load(FLAGS.resume)
@@ -108,14 +110,14 @@ def main():
             data_B = data_B[0][0] if not FLAGS.dynamic else data_B[0]
             start = time.time()
-            fake_B = g_AB.test(data_A)[0]
-            fake_A = g_BA.test(data_B)[0]
-            g_loss = g.train([data_A, data_B])[0]
+            fake_B = g_AB.test_batch(data_A)[0]
+            fake_A = g_BA.test_batch(data_B)[0]
+            g_loss = g.train_batch([data_A, data_B])[0]
             fake_pb = B_pool.get(fake_B)
-            da_loss = d_A.train([data_B, fake_pb])[0]
+            da_loss = d_A.train_batch([data_B, fake_pb])[0]
             fake_pa = A_pool.get(fake_A)
-            db_loss = d_B.train([data_A, fake_pa])[0]
+            db_loss = d_B.train_batch([data_A, fake_pa])[0]
             t = time.time() - start
             if i % 20 == 0:
@@ -128,7 +130,7 @@ def main():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("CycleGAN Training on Cityscapes")
     parser.add_argument(
-        "-d", "--dynamic", action='store_false', help="Enable dygraph mode")
+        "-d", "--dynamic", action='store_true', help="Enable dygraph mode")
     parser.add_argument(
         "-p",
         "--device",
...
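Besides the renames, all three scripts flip the -d/--dynamic flag from action='store_false' to action='store_true', so dygraph mode is off by default and passing -d turns it on, which matches the help text. A small self-contained illustration of the corrected flag behavior (not part of the commit):

import argparse

parser = argparse.ArgumentParser("CycleGAN Training on Cityscapes")
parser.add_argument(
    "-d", "--dynamic", action='store_true', help="Enable dygraph mode")

print(parser.parse_args([]).dynamic)      # False: static graph by default
print(parser.parse_args(["-d"]).dynamic)  # True: -d enables dygraph mode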
@@ -193,17 +193,17 @@ class StaticGraphAdapter(object):
     def mode(self, value):
         self.model.mode = value
-    def train(self, inputs, labels=None):
+    def train_batch(self, inputs, labels=None):
         assert self.model._optimizer, \
             "model not ready, please call `model.prepare()` first"
         self.mode = 'train'
         return self._run(inputs, labels)
-    def eval(self, inputs, labels=None):
+    def eval_batch(self, inputs, labels=None):
         self.mode = 'eval'
         return self._run(inputs, labels)
-    def test(self, inputs):
+    def test_batch(self, inputs):
         self.mode = 'test'
         return self._run(inputs, None)
@@ -567,7 +567,7 @@ class DynamicGraphAdapter(object):
         self.model.mode = value
     # TODO multi device in dygraph mode not implemented at present time
-    def train(self, inputs, labels=None):
+    def train_batch(self, inputs, labels=None):
         assert self.model._optimizer, \
             "model not ready, please call `model.prepare()` first"
         super(Model, self.model).train()
@@ -600,7 +600,7 @@ class DynamicGraphAdapter(object):
         return ([to_numpy(l) for l in losses], metrics) \
             if len(metrics) > 0 else [to_numpy(l) for l in losses]
-    def eval(self, inputs, labels=None):
+    def eval_batch(self, inputs, labels=None):
         super(Model, self.model).eval()
         self.mode = 'eval'
         inputs = to_list(inputs)
@@ -641,7 +641,7 @@ class DynamicGraphAdapter(object):
         return ([to_numpy(l) for l in losses], metrics) \
             if len(metrics) > 0 else [to_numpy(l) for l in losses]
-    def test(self, inputs):
+    def test_batch(self, inputs):
         super(Model, self.model).eval()
         self.mode = 'test'
         inputs = [to_variable(x) for x in to_list(inputs)]
@@ -740,14 +740,14 @@ class Model(fluid.dygraph.Layer):
         else:
             self._adapter = StaticGraphAdapter(self)
-    def train(self, *args, **kwargs):
-        return self._adapter.train(*args, **kwargs)
+    def train_batch(self, *args, **kwargs):
+        return self._adapter.train_batch(*args, **kwargs)
-    def eval(self, *args, **kwargs):
-        return self._adapter.eval(*args, **kwargs)
+    def eval_batch(self, *args, **kwargs):
+        return self._adapter.eval_batch(*args, **kwargs)
-    def test(self, *args, **kwargs):
-        return self._adapter.test(*args, **kwargs)
+    def test_batch(self, *args, **kwargs):
+        return self._adapter.test_batch(*args, **kwargs)
     def save(self, *args, **kwargs):
         if ParallelEnv().local_rank == 0:
@@ -1173,7 +1173,7 @@ class Model(fluid.dygraph.Layer):
         outputs = []
         for data in tqdm.tqdm(loader):
             data = flatten(data)
-            outputs.append(self.test(data[:len(self._inputs)]))
+            outputs.append(self.test_batch(data[:len(self._inputs)]))
             # NOTE: for lod tensor output, we should not stack outputs
             # for stacking may loss its detail info
@@ -1187,18 +1187,6 @@ class Model(fluid.dygraph.Layer):
             outputs = [o[:len(test_loader.dataset)] for o in outputs]
         return outputs
-    def set_eval_data(self, eval_data):
-        """
-        Args:
-            eval_data (Dataset|DataLoader|None): An iterable data loader is used for
-                eval. An instance of paddle.io.Dataset or
-                paddle.io.Dataloader is recomended.
-        """
-        assert isinstance(
-            eval_data,
-            DataLoader), "eval_data must be a instance of Dataloader!"
-        self._test_dataloader = eval_data
     def _run_one_epoch(self,
                        data_loader,
                        callbacks,
@@ -1235,11 +1223,11 @@ class Model(fluid.dygraph.Layer):
             callbacks.on_batch_begin(mode, step, logs)
             if mode == 'train':
-                outs = self.train(data[:len(self._inputs)],
-                                  data[len(self._inputs):])
+                outs = self.train_batch(data[:len(self._inputs)],
+                                        data[len(self._inputs):])
             else:
-                outs = self.eval(data[:len(self._inputs)],
-                                 data[len(self._inputs):])
+                outs = self.eval_batch(data[:len(self._inputs)],
+                                       data[len(self._inputs):])
             # losses
             loss = outs[0] if self._metrics else outs
...
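Inside the hapi.model module the rename is mechanical: Model keeps delegating each batch-level call to the adapter chosen at construction time (DynamicGraphAdapter or StaticGraphAdapter), internal call sites such as _run_one_epoch and the inference loop above switch to the new names, and the set_eval_data helper is removed. A simplified, standalone illustration of that delegation pattern with toy classes (not hapi code):

class _ToyAdapter(object):
    """Toy stand-in for StaticGraphAdapter / DynamicGraphAdapter."""

    def __init__(self, name):
        self.name = name

    def train_batch(self, inputs, labels=None):
        return self.name + ".train_batch"

    def eval_batch(self, inputs, labels=None):
        return self.name + ".eval_batch"

    def test_batch(self, inputs):
        return self.name + ".test_batch"


class TinyModel(object):
    """Mirrors how Model forwards the renamed calls to its adapter."""

    def __init__(self, dygraph_enabled):
        name = "DynamicGraphAdapter" if dygraph_enabled else "StaticGraphAdapter"
        self._adapter = _ToyAdapter(name)

    def train_batch(self, *args, **kwargs):
        return self._adapter.train_batch(*args, **kwargs)

    def eval_batch(self, *args, **kwargs):
        return self._adapter.eval_batch(*args, **kwargs)

    def test_batch(self, *args, **kwargs):
        return self._adapter.test_batch(*args, **kwargs)


print(TinyModel(True).train_batch([0], [1]))  # DynamicGraphAdapter.train_batch
print(TinyModel(False).test_batch([0]))       # StaticGraphAdapter.test_batch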