未验证 提交 8875bb25 编写于 作者: W whs 提交者: GitHub

Fix cycle gan: (#1829) (#1835)

1. use_cudnn=False
2. fix saving checkponint
3. using compiled program
上级 a8976f5c
......@@ -74,8 +74,8 @@ env CUDA_VISIBLE_DEVICES=0 python train.py
```
env CUDA_VISIBLE_DEVICE=0 python infer.py \
--init_model="models/1" --input="./data/inputA/*" \
--output="./output"
--init_model="checkpoints/1" --input="./data/inputA/*" \
--input_style A --output="./output"
```
训练150轮的模型预测效果如图2和图3所示:
......
......@@ -26,8 +26,10 @@ def infer(args):
data_shape = [-1, 3, 256, 256]
input = fluid.layers.data(name='input', shape=data_shape, dtype='float32')
if args.input_style == "A":
model_name = 'g_a'
fake = build_generator_resnet_9blocks(input, name="g_A")
elif args.input_style == "B":
model_name = 'g_b'
fake = build_generator_resnet_9blocks(input, name="g_B")
else:
raise "Input with style [%s] is not supported." % args.input_style
......@@ -37,7 +39,7 @@ def infer(args):
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, args.init_model)
fluid.io.load_persistables(exe, args.init_model + "/" + model_name)
if not os.path.exists(args.output):
os.makedirs(args.output)
......
......@@ -3,10 +3,12 @@ import paddle.fluid as fluid
import numpy as np
import os
use_cudnn = True
# cudnn is not better when batch size is 1.
use_cudnn = False
if 'ce_mode' in os.environ:
use_cudnn = False
def cal_padding(img_size, stride, filter_size, dilation=1):
"""Calculate padding size."""
valid_filter_size = dilation * (filter_size - 1) + 1
......@@ -18,6 +20,8 @@ def cal_padding(img_size, stride, filter_size, dilation=1):
def instance_norm(input, name=None):
# TODO(lvmengsi@baidu.com): Check the accuracy when using fluid.layers.layer_norm.
# return fluid.layers.layer_norm(input, begin_norm_axis=2)
helper = fluid.layer_helper.LayerHelper("instance_norm", **locals())
dtype = helper.input_dtype()
epsilon = 1e-5
......
......@@ -17,7 +17,6 @@ import data_reader
from utility import add_arguments, print_arguments, ImagePool
from trainer import *
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
......@@ -36,7 +35,7 @@ add_arg('run_ce', bool, False, "Whether to run for model ce.")
def train(args):
max_images_num = data_reader.max_images_num()
shuffle=True
shuffle = True
if args.run_ce:
np.random.seed(10)
fluid.default_startup_program().random_seed = 90
......@@ -67,8 +66,10 @@ def train(args):
A_pool = ImagePool()
B_pool = ImagePool()
A_reader = paddle.batch(data_reader.a_reader(shuffle=shuffle), args.batch_size)()
B_reader = paddle.batch(data_reader.b_reader(shuffle=shuffle), args.batch_size)()
A_reader = paddle.batch(
data_reader.a_reader(shuffle=shuffle), args.batch_size)()
B_reader = paddle.batch(
data_reader.b_reader(shuffle=shuffle), args.batch_size)()
if not args.run_ce:
A_test_reader = data_reader.a_test_reader()
B_test_reader = data_reader.b_test_reader()
......@@ -119,13 +120,13 @@ def train(args):
if not os.path.exists(out_path):
os.makedirs(out_path)
fluid.io.save_persistables(
exe, out_path + "/g_a", main_program=g_A_trainer.program, filename="params")
exe, out_path + "/g_a", main_program=g_A_trainer.program)
fluid.io.save_persistables(
exe, out_path + "/g_b", main_program=g_B_trainer.program, filename="params")
exe, out_path + "/g_b", main_program=g_B_trainer.program)
fluid.io.save_persistables(
exe, out_path + "/d_a", main_program=d_A_trainer.program, filename="params")
exe, out_path + "/d_a", main_program=d_A_trainer.program)
fluid.io.save_persistables(
exe, out_path + "/d_b", main_program=d_B_trainer.program, filename="params")
exe, out_path + "/d_b", main_program=d_B_trainer.program)
print("saved checkpoint to {}".format(out_path))
sys.stdout.flush()
......@@ -144,8 +145,21 @@ def train(args):
if args.init_model:
init_model()
losses=[[], []]
losses = [[], []]
t_time = 0
g_A_trainer_program = fluid.CompiledProgram(
g_A_trainer.program).with_data_parallel(
loss_name=g_A_trainer.g_loss_A.name)
g_B_trainer_program = fluid.CompiledProgram(
g_B_trainer.program).with_data_parallel(
loss_name=g_B_trainer.g_loss_B.name)
d_B_trainer_program = fluid.CompiledProgram(
d_B_trainer.program).with_data_parallel(
loss_name=d_B_trainer.d_loss_B.name)
d_A_trainer_program = fluid.CompiledProgram(
d_A_trainer.program).with_data_parallel(
loss_name=d_A_trainer.d_loss_A.name)
for epoch in range(args.epoch):
batch_id = 0
for i in range(max_images_num):
......@@ -158,7 +172,7 @@ def train(args):
s_time = time.time()
# optimize the g_A network
g_A_loss, fake_B_tmp = exe.run(
g_A_trainer.program,
g_A_trainer_program,
fetch_list=[g_A_trainer.g_loss_A, g_A_trainer.fake_B],
feed={"input_A": tensor_A,
"input_B": tensor_B})
......@@ -167,14 +181,14 @@ def train(args):
# optimize the d_B network
d_B_loss = exe.run(
d_B_trainer.program,
d_B_trainer_program,
fetch_list=[d_B_trainer.d_loss_B],
feed={"input_B": tensor_B,
"fake_pool_B": fake_pool_B})[0]
# optimize the g_B network
g_B_loss, fake_A_tmp = exe.run(
g_B_trainer.program,
g_B_trainer_program,
fetch_list=[g_B_trainer.g_loss_B, g_B_trainer.fake_A],
feed={"input_A": tensor_A,
"input_B": tensor_B})
......@@ -183,16 +197,16 @@ def train(args):
# optimize the d_A network
d_A_loss = exe.run(
d_A_trainer.program,
d_A_trainer_program,
fetch_list=[d_A_trainer.d_loss_A],
feed={"input_A": tensor_A,
"fake_pool_A": fake_pool_A})[0]
batch_time = time.time() - s_time
t_time += batch_time
print("epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {}; "
"Batch_time_cost: {:.2f}".format(
epoch, batch_id, g_A_loss[0], d_B_loss[0], g_B_loss[0],
d_A_loss[0], batch_time))
print(
"epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {}; "
"Batch_time_cost: {:.2f}".format(epoch, batch_id, g_A_loss[
0], d_B_loss[0], g_B_loss[0], d_A_loss[0], batch_time))
losses[0].append(g_A_loss[0])
losses[1].append(d_A_loss[0])
sys.stdout.flush()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册