未验证 提交 8a0431bb 编写于 作者: L lvmengsi 提交者: GitHub

fix_cycle_pix (#2403)

* fix bug in cycle and pix
上级 3c303e97
......@@ -49,12 +49,12 @@ def infer(args):
input = fluid.layers.data(name='input', shape=data_shape, dtype='float32')
model_name = 'net_G'
if args.model_net == 'cyclegan':
from network.CycleGAN_network import network_G, network_D
from network.CycleGAN_network import CycleGAN_model
model = CycleGAN_model()
if args.input_style == "A":
fake = network_G(input, name="GA", cfg=args)
fake = model.network_G(input, name="GA", cfg=args)
elif args.input_style == "B":
fake = network_G(input, name="GB", cfg=args)
fake = model.network_G(input, name="GB", cfg=args)
else:
raise "Input with style [%s] is not supported." % args.input_style
elif args.model_net == 'Pix2pix':
......
python infer.py --init_model output/checkpoints/199/ --input "data/cityscapes/testA/*" --input_style A --model_net cyclegan --net_G resnet_6block --g_bash_dims 32
python infer.py --init_model output/checkpoints/199/ --input data/cityscapes/testA/* --input_style A --model_net cyclegan --net_G resnet_6block --g_base_dims 32
import os
import argparse
parser = argparse.ArgumentParser(description='the direction of data list')
parser.add_argument(
'--direction', type=str, default='A2B', help='the direction of data list')
def make_pair_data(fileA, file):
def make_pair_data(fileA, file, d):
f = open(fileA, 'r')
lines = f.readlines()
w = open(file, 'w')
......@@ -10,16 +15,22 @@ def make_pair_data(fileA, file):
print(fileA)
fileB = fileA.replace("A", "B")
print(fileB)
if d == 'A2B':
l = fileA + '\t' + fileB + '\n'
elif d == 'B2A':
l = fileB + '\t' + fileA + '\n'
else:
raise NotImplementedError("the direction: [%s] is not support" % d)
w.write(l)
w.close()
if __name__ == "__main__":
args = parser.parse_args()
trainA_file = "./data/cityscapes/trainA.txt"
train_file = "./data/cityscapes/pix2pix_train_list"
make_pair_data(trainA_file, train_file)
make_pair_data(trainA_file, train_file, args.direction)
testA_file = "./data/cityscapes/testA.txt"
test_file = "./data/cityscapes/pix2pix_test_list"
make_pair_data(testA_file, test_file)
make_pair_data(testA_file, test_file, args.direction)
......@@ -88,13 +88,19 @@ class GTrainer():
vars.append(var.name)
self.param = vars
lr = cfg.learning_rate
if cfg.epoch <= 100:
optimizer = fluid.optimizer.Adam(
learning_rate=lr, beta1=0.5, beta2=0.999, name="net_G")
else:
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=[99 * step_per_epoch] +
[x * step_per_epoch for x in range(100, cfg.epoch - 1)],
boundaries=[99 * step_per_epoch] + [
x * step_per_epoch
for x in xrange(100, cfg.epoch - 1)
],
values=[lr] + [
lr * (1.0 - (x - 99.0) / 101.0)
for x in range(100, cfg.epoch)
for x in xrange(100, cfg.epoch)
]),
beta1=0.5,
beta2=0.999,
......@@ -122,13 +128,19 @@ class DATrainer():
self.param = vars
lr = cfg.learning_rate
if cfg.epoch <= 100:
optimizer = fluid.optimizer.Adam(
learning_rate=lr, beta1=0.5, beta2=0.999, name="net_DA")
else:
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=[99 * step_per_epoch] +
[x * step_per_epoch for x in range(100, cfg.epoch - 1)],
boundaries=[99 * step_per_epoch] + [
x * step_per_epoch
for x in xrange(100, cfg.epoch - 1)
],
values=[lr] + [
lr * (1.0 - (x - 99.0) / 101.0)
for x in range(100, cfg.epoch)
for x in xrange(100, cfg.epoch)
]),
beta1=0.5,
beta2=0.999,
......@@ -155,13 +167,19 @@ class DBTrainer():
vars.append(var.name)
self.param = vars
lr = 0.0002
if cfg.epoch <= 100:
optimizer = fluid.optimizer.Adam(
learning_rate=lr, beta1=0.5, beta2=0.999, name="net_DA")
else:
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=[99 * step_per_epoch] +
[x * step_per_epoch for x in range(100, cfg.epoch - 1)],
boundaries=[99 * step_per_epoch] + [
x * step_per_epoch
for x in xrange(100, cfg.epoch - 1)
],
values=[lr] + [
lr * (1.0 - (x - 99.0) / 101.0)
for x in range(100, cfg.epoch)
for x in xrange(100, cfg.epoch)
]),
beta1=0.5,
beta2=0.999,
......
......@@ -70,10 +70,16 @@ class GTrainer():
"generator"):
vars.append(var.name)
self.param = vars
if cfg.epoch <= 100:
optimizer = fluid.optimizer.Adam(
learning_rate=lr, beta1=0.5, beta2=0.999, name="net_G")
else:
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=[99 * step_per_epoch] +
[x * step_per_epoch for x in range(100, cfg.epoch - 1)],
boundaries=[99 * step_per_epoch] + [
x * step_per_epoch
for x in range(100, cfg.epoch - 1)
],
values=[lr] + [
lr * (1.0 - (x - 99.0) / 101.0)
for x in range(100, cfg.epoch)
......@@ -142,10 +148,16 @@ class DTrainer():
vars.append(var.name)
self.param = vars
if cfg.epoch <= 100:
optimizer = fluid.optimizer.Adam(
learning_rate=lr, beta1=0.5, beta2=0.999, name="net_D")
else:
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=[99 * step_per_epoch] +
[x * step_per_epoch for x in range(100, cfg.epoch - 1)],
boundaries=[99 * step_per_epoch] + [
x * step_per_epoch
for x in range(100, cfg.epoch - 1)
],
values=[lr] + [
lr * (1.0 - (x - 99.0) / 101.0)
for x in range(100, cfg.epoch)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册