未验证 提交 8875bb25 编写于 作者: W whs 提交者: GitHub

Fix cycle gan: (#1829) (#1835)

1. use_cudnn=False
2. fix saving checkponint
3. using compiled program
上级 a8976f5c
...@@ -74,8 +74,8 @@ env CUDA_VISIBLE_DEVICES=0 python train.py ...@@ -74,8 +74,8 @@ env CUDA_VISIBLE_DEVICES=0 python train.py
``` ```
env CUDA_VISIBLE_DEVICE=0 python infer.py \ env CUDA_VISIBLE_DEVICE=0 python infer.py \
--init_model="models/1" --input="./data/inputA/*" \ --init_model="checkpoints/1" --input="./data/inputA/*" \
--output="./output" --input_style A --output="./output"
``` ```
训练150轮的模型预测效果如图2和图3所示: 训练150轮的模型预测效果如图2和图3所示:
......
...@@ -26,8 +26,10 @@ def infer(args): ...@@ -26,8 +26,10 @@ def infer(args):
data_shape = [-1, 3, 256, 256] data_shape = [-1, 3, 256, 256]
input = fluid.layers.data(name='input', shape=data_shape, dtype='float32') input = fluid.layers.data(name='input', shape=data_shape, dtype='float32')
if args.input_style == "A": if args.input_style == "A":
model_name = 'g_a'
fake = build_generator_resnet_9blocks(input, name="g_A") fake = build_generator_resnet_9blocks(input, name="g_A")
elif args.input_style == "B": elif args.input_style == "B":
model_name = 'g_b'
fake = build_generator_resnet_9blocks(input, name="g_B") fake = build_generator_resnet_9blocks(input, name="g_B")
else: else:
raise "Input with style [%s] is not supported." % args.input_style raise "Input with style [%s] is not supported." % args.input_style
...@@ -37,7 +39,7 @@ def infer(args): ...@@ -37,7 +39,7 @@ def infer(args):
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, args.init_model) fluid.io.load_persistables(exe, args.init_model + "/" + model_name)
if not os.path.exists(args.output): if not os.path.exists(args.output):
os.makedirs(args.output) os.makedirs(args.output)
......
...@@ -3,10 +3,12 @@ import paddle.fluid as fluid ...@@ -3,10 +3,12 @@ import paddle.fluid as fluid
import numpy as np import numpy as np
import os import os
use_cudnn = True # cudnn is not better when batch size is 1.
use_cudnn = False
if 'ce_mode' in os.environ: if 'ce_mode' in os.environ:
use_cudnn = False use_cudnn = False
def cal_padding(img_size, stride, filter_size, dilation=1): def cal_padding(img_size, stride, filter_size, dilation=1):
"""Calculate padding size.""" """Calculate padding size."""
valid_filter_size = dilation * (filter_size - 1) + 1 valid_filter_size = dilation * (filter_size - 1) + 1
...@@ -18,6 +20,8 @@ def cal_padding(img_size, stride, filter_size, dilation=1): ...@@ -18,6 +20,8 @@ def cal_padding(img_size, stride, filter_size, dilation=1):
def instance_norm(input, name=None): def instance_norm(input, name=None):
# TODO(lvmengsi@baidu.com): Check the accuracy when using fluid.layers.layer_norm.
# return fluid.layers.layer_norm(input, begin_norm_axis=2)
helper = fluid.layer_helper.LayerHelper("instance_norm", **locals()) helper = fluid.layer_helper.LayerHelper("instance_norm", **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
epsilon = 1e-5 epsilon = 1e-5
......
...@@ -17,7 +17,6 @@ import data_reader ...@@ -17,7 +17,6 @@ import data_reader
from utility import add_arguments, print_arguments, ImagePool from utility import add_arguments, print_arguments, ImagePool
from trainer import * from trainer import *
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
...@@ -36,7 +35,7 @@ add_arg('run_ce', bool, False, "Whether to run for model ce.") ...@@ -36,7 +35,7 @@ add_arg('run_ce', bool, False, "Whether to run for model ce.")
def train(args): def train(args):
max_images_num = data_reader.max_images_num() max_images_num = data_reader.max_images_num()
shuffle=True shuffle = True
if args.run_ce: if args.run_ce:
np.random.seed(10) np.random.seed(10)
fluid.default_startup_program().random_seed = 90 fluid.default_startup_program().random_seed = 90
...@@ -66,9 +65,11 @@ def train(args): ...@@ -66,9 +65,11 @@ def train(args):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
A_pool = ImagePool() A_pool = ImagePool()
B_pool = ImagePool() B_pool = ImagePool()
A_reader = paddle.batch(data_reader.a_reader(shuffle=shuffle), args.batch_size)() A_reader = paddle.batch(
B_reader = paddle.batch(data_reader.b_reader(shuffle=shuffle), args.batch_size)() data_reader.a_reader(shuffle=shuffle), args.batch_size)()
B_reader = paddle.batch(
data_reader.b_reader(shuffle=shuffle), args.batch_size)()
if not args.run_ce: if not args.run_ce:
A_test_reader = data_reader.a_test_reader() A_test_reader = data_reader.a_test_reader()
B_test_reader = data_reader.b_test_reader() B_test_reader = data_reader.b_test_reader()
...@@ -119,13 +120,13 @@ def train(args): ...@@ -119,13 +120,13 @@ def train(args):
if not os.path.exists(out_path): if not os.path.exists(out_path):
os.makedirs(out_path) os.makedirs(out_path)
fluid.io.save_persistables( fluid.io.save_persistables(
exe, out_path + "/g_a", main_program=g_A_trainer.program, filename="params") exe, out_path + "/g_a", main_program=g_A_trainer.program)
fluid.io.save_persistables( fluid.io.save_persistables(
exe, out_path + "/g_b", main_program=g_B_trainer.program, filename="params") exe, out_path + "/g_b", main_program=g_B_trainer.program)
fluid.io.save_persistables( fluid.io.save_persistables(
exe, out_path + "/d_a", main_program=d_A_trainer.program, filename="params") exe, out_path + "/d_a", main_program=d_A_trainer.program)
fluid.io.save_persistables( fluid.io.save_persistables(
exe, out_path + "/d_b", main_program=d_B_trainer.program, filename="params") exe, out_path + "/d_b", main_program=d_B_trainer.program)
print("saved checkpoint to {}".format(out_path)) print("saved checkpoint to {}".format(out_path))
sys.stdout.flush() sys.stdout.flush()
...@@ -144,8 +145,21 @@ def train(args): ...@@ -144,8 +145,21 @@ def train(args):
if args.init_model: if args.init_model:
init_model() init_model()
losses=[[], []] losses = [[], []]
t_time = 0 t_time = 0
g_A_trainer_program = fluid.CompiledProgram(
g_A_trainer.program).with_data_parallel(
loss_name=g_A_trainer.g_loss_A.name)
g_B_trainer_program = fluid.CompiledProgram(
g_B_trainer.program).with_data_parallel(
loss_name=g_B_trainer.g_loss_B.name)
d_B_trainer_program = fluid.CompiledProgram(
d_B_trainer.program).with_data_parallel(
loss_name=d_B_trainer.d_loss_B.name)
d_A_trainer_program = fluid.CompiledProgram(
d_A_trainer.program).with_data_parallel(
loss_name=d_A_trainer.d_loss_A.name)
for epoch in range(args.epoch): for epoch in range(args.epoch):
batch_id = 0 batch_id = 0
for i in range(max_images_num): for i in range(max_images_num):
...@@ -158,7 +172,7 @@ def train(args): ...@@ -158,7 +172,7 @@ def train(args):
s_time = time.time() s_time = time.time()
# optimize the g_A network # optimize the g_A network
g_A_loss, fake_B_tmp = exe.run( g_A_loss, fake_B_tmp = exe.run(
g_A_trainer.program, g_A_trainer_program,
fetch_list=[g_A_trainer.g_loss_A, g_A_trainer.fake_B], fetch_list=[g_A_trainer.g_loss_A, g_A_trainer.fake_B],
feed={"input_A": tensor_A, feed={"input_A": tensor_A,
"input_B": tensor_B}) "input_B": tensor_B})
...@@ -167,14 +181,14 @@ def train(args): ...@@ -167,14 +181,14 @@ def train(args):
# optimize the d_B network # optimize the d_B network
d_B_loss = exe.run( d_B_loss = exe.run(
d_B_trainer.program, d_B_trainer_program,
fetch_list=[d_B_trainer.d_loss_B], fetch_list=[d_B_trainer.d_loss_B],
feed={"input_B": tensor_B, feed={"input_B": tensor_B,
"fake_pool_B": fake_pool_B})[0] "fake_pool_B": fake_pool_B})[0]
# optimize the g_B network # optimize the g_B network
g_B_loss, fake_A_tmp = exe.run( g_B_loss, fake_A_tmp = exe.run(
g_B_trainer.program, g_B_trainer_program,
fetch_list=[g_B_trainer.g_loss_B, g_B_trainer.fake_A], fetch_list=[g_B_trainer.g_loss_B, g_B_trainer.fake_A],
feed={"input_A": tensor_A, feed={"input_A": tensor_A,
"input_B": tensor_B}) "input_B": tensor_B})
...@@ -183,16 +197,16 @@ def train(args): ...@@ -183,16 +197,16 @@ def train(args):
# optimize the d_A network # optimize the d_A network
d_A_loss = exe.run( d_A_loss = exe.run(
d_A_trainer.program, d_A_trainer_program,
fetch_list=[d_A_trainer.d_loss_A], fetch_list=[d_A_trainer.d_loss_A],
feed={"input_A": tensor_A, feed={"input_A": tensor_A,
"fake_pool_A": fake_pool_A})[0] "fake_pool_A": fake_pool_A})[0]
batch_time = time.time() - s_time batch_time = time.time() - s_time
t_time += batch_time t_time += batch_time
print("epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {}; " print(
"Batch_time_cost: {:.2f}".format( "epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {}; "
epoch, batch_id, g_A_loss[0], d_B_loss[0], g_B_loss[0], "Batch_time_cost: {:.2f}".format(epoch, batch_id, g_A_loss[
d_A_loss[0], batch_time)) 0], d_B_loss[0], g_B_loss[0], d_A_loss[0], batch_time))
losses[0].append(g_A_loss[0]) losses[0].append(g_A_loss[0])
losses[1].append(d_A_loss[0]) losses[1].append(d_A_loss[0])
sys.stdout.flush() sys.stdout.flush()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册