未验证 提交 22c785db 编写于 作者: L lvmengsi 提交者: GitHub

update_save (#3173)

上级 5322ad1b
......@@ -27,8 +27,8 @@ import six
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import imageio
import copy
from PIL import Image
img_dim = 28
......@@ -158,15 +158,18 @@ def save_test_image(epoch,
image_name).astype('int32')[0]]
inputB_name = "inputB_" + str(epoch) + "_" + A_id2name[np.array(
image_name).astype('int32')[0]]
imageio.imwrite(
os.path.join(out_path, fakeB_name), (
(fake_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(out_path, inputA_name), (
(input_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(out_path, inputB_name), (
(input_B_temp + 1) * 127.5).astype(np.uint8))
res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
np.uint8))
res_fakeB.save(os.path.join(out_path, fakeB_name))
res_inputA = Image.fromarray(((input_A_temp + 1) * 127.5).astype(
np.uint8))
res_inputA.save(os.path.join(out_path, inputA_name))
res_inputB = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8))
res_inputB.save(os.path.join(out_path, inputB_name))
elif cfg.model_net == "StarGAN":
for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][
......@@ -199,9 +202,11 @@ def save_test_image(epoch,
images_concat = np.concatenate(images_concat, 1)
image_name_save = "fake_img" + str(epoch) + "_" + str(
np.array(image_name)[0].astype('int32')) + '.jpg'
imageio.imwrite(
os.path.join(out_path, image_name_save), (
(images_concat + 1) * 127.5).astype(np.uint8))
res = Image.fromarray(((images_concat + 1) * 127.5).astype(
np.uint8))
res.save(os.path.join(out_path, image_name_save))
elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][
......@@ -247,9 +252,10 @@ def save_test_image(epoch,
images_concat = np.concatenate(images_concat, 1)
image_name_save = 'fake_img_' + str(epoch) + '_' + str(
np.array(image_name)[0].astype('int32')) + '.jpg'
image_path = os.path.join(out_path, image_name_save)
imageio.imwrite(image_path, (
(images_concat + 1) * 127.5).astype(np.uint8))
res = Image.fromarray(((images_concat + 1) * 127.5).astype(
np.uint8))
res.save(os.path.join(out_path, image_name_save))
else:
for data_A, data_B in zip(A_test_reader(), B_test_reader()):
......@@ -282,24 +288,30 @@ def save_test_image(epoch,
A_name).astype('int32')[0]]
cycB_name = "cycB_" + str(epoch) + "_" + B_id2name[np.array(
B_name).astype('int32')[0]]
imageio.imwrite(
os.path.join(out_path, fakeB_name), (
(fake_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(out_path, fakeA_name), (
(fake_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(out_path, cycA_name), (
(cyc_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(out_path, cycB_name), (
(cyc_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(out_path, inputA_name), (
(input_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(out_path, inputB_name), (
(input_B_temp + 1) * 127.5).astype(np.uint8))
res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
np.uint8))
res_fakeB.save(os.path.join(out_path, fakeB_name))
res_fakeA = Image.fromarray(((fake_A_temp + 1) * 127.5).astype(
np.uint8))
res_fakeA.save(os.path.join(out_path, fakeA_name))
res_cycA = Image.fromarray(((cyc_A_temp + 1) * 127.5).astype(
np.uint8))
res_cycA.save(os.path.join(out_path, cycA_name))
res_cycB = Image.fromarray(((cyc_B_temp + 1) * 127.5).astype(
np.uint8))
res_cycB.save(os.path.join(out_path, cycB_name))
res_inputA = Image.fromarray(((input_A_temp + 1) * 127.5).astype(
np.uint8))
res_inputA.save(os.path.join(out_path, inputA_name))
res_inputB = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8))
res_inputB.save(os.path.join(out_path, inputB_name))
class ImagePool(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册