未验证 提交 344a0fd5 编写于 作者: C ceci3 提交者: GitHub

fix cyclegan (#4253)

上级 bc1fe15c
......@@ -305,16 +305,22 @@ def infer(args):
id2name = test_reader.id2name
for data in loader():
real_img, image_name = data[0]['input'], data[0]['image_name']
image_name = id2name[np.array(image_name).astype('int32')[0]]
print("read: ", image_name)
image_names = []
for name in image_name:
image_names.append(id2name[np.array(name).astype('int32')[0]])
print("read: ", image_names)
fake_temp = exe.run(fetch_list=[fake.name],
feed={"input": real_img})
fake_temp = np.squeeze(fake_temp[0]).transpose([1, 2, 0])
input_temp = np.squeeze(np.array(real_img)[0]).transpose([1, 2, 0])
fake_temp = save_batch_image(fake_temp[0])
input_temp = save_batch_image(np.array(real_img))
imageio.imwrite(
os.path.join(args.output, "fake_" + image_name), (
(fake_temp + 1) * 127.5).astype(np.uint8))
for i, name in enumerate(image_names):
imageio.imwrite(
os.path.join(args.output, "fake_" + name), (
(fake_temp[i] + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(args.output, "input_" + name), (
(input_temp[i] + 1) * 127.5).astype(np.uint8))
elif args.model_net == 'SPADE':
test_reader = triplex_reader_creator(
image_dir=args.dataset_dir,
......
......@@ -323,6 +323,10 @@ class CycleGAN(object):
fake_pool_B = B_pool.pool_image(fake_B_tmp)
fake_pool_A = A_pool.pool_image(fake_A_tmp)
if self.cfg.enable_ce:
fake_pool_B = fake_B_tmp
fake_pool_A = fake_A_tmp
# optimize the d_A network
d_A_loss = exe.run(
d_A_trainer_program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册