提交 70128735 编写于 作者: L liuwei1031

disable default memory optimize since it affect the efficiency a lot, test=develop

上级 752c2bed
......@@ -147,19 +147,22 @@ def train(args):
init_model()
losses = [[], []]
t_time = 0
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
g_A_trainer_program = fluid.CompiledProgram(
g_A_trainer.program).with_data_parallel(
loss_name=g_A_trainer.g_loss_A.name)
loss_name=g_A_trainer.g_loss_A.name, build_strategy=build_strategy)
g_B_trainer_program = fluid.CompiledProgram(
g_B_trainer.program).with_data_parallel(
loss_name=g_B_trainer.g_loss_B.name)
loss_name=g_B_trainer.g_loss_B.name, build_strategy=build_strategy)
d_B_trainer_program = fluid.CompiledProgram(
d_B_trainer.program).with_data_parallel(
loss_name=d_B_trainer.d_loss_B.name)
loss_name=d_B_trainer.d_loss_B.name, build_strategy=build_strategy)
d_A_trainer_program = fluid.CompiledProgram(
d_A_trainer.program).with_data_parallel(
loss_name=d_A_trainer.d_loss_A.name)
loss_name=d_A_trainer.d_loss_A.name, build_strategy=build_strategy)
for epoch in range(args.epoch):
batch_id = 0
for i in range(max_images_num):
......
......@@ -13,8 +13,6 @@ class GATrainer():
self.program = fluid.default_main_program().clone()
with fluid.program_guard(self.program):
self.fake_B = build_generator_resnet_9blocks(input_A, name="g_A")
#FIXME set persistable explicitly to pass CE
self.fake_B.persistable = True
self.fake_A = build_generator_resnet_9blocks(input_B, name="g_B")
self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B")
self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A")
......@@ -60,8 +58,6 @@ class GBTrainer():
with fluid.program_guard(self.program):
self.fake_B = build_generator_resnet_9blocks(input_A, name="g_A")
self.fake_A = build_generator_resnet_9blocks(input_B, name="g_B")
#FIXME set persistable explicitly to pass CE
self.fake_A.persistable = True
self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B")
self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A")
self.infer_program = self.program.clone()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册