未验证 提交 22c785db 编写于 作者: L lvmengsi 提交者: GitHub

update_save (#3173)

上级 5322ad1b
...@@ -27,8 +27,8 @@ import six ...@@ -27,8 +27,8 @@ import six
matplotlib.use('agg') matplotlib.use('agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec import matplotlib.gridspec as gridspec
import imageio
import copy import copy
from PIL import Image
img_dim = 28 img_dim = 28
...@@ -158,15 +158,18 @@ def save_test_image(epoch, ...@@ -158,15 +158,18 @@ def save_test_image(epoch,
image_name).astype('int32')[0]] image_name).astype('int32')[0]]
inputB_name = "inputB_" + str(epoch) + "_" + A_id2name[np.array( inputB_name = "inputB_" + str(epoch) + "_" + A_id2name[np.array(
image_name).astype('int32')[0]] image_name).astype('int32')[0]]
imageio.imwrite(
os.path.join(out_path, fakeB_name), ( res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
(fake_B_temp + 1) * 127.5).astype(np.uint8)) np.uint8))
imageio.imwrite( res_fakeB.save(os.path.join(out_path, fakeB_name))
os.path.join(out_path, inputA_name), (
(input_A_temp + 1) * 127.5).astype(np.uint8)) res_inputA = Image.fromarray(((input_A_temp + 1) * 127.5).astype(
imageio.imwrite( np.uint8))
os.path.join(out_path, inputB_name), ( res_inputA.save(os.path.join(out_path, inputA_name))
(input_B_temp + 1) * 127.5).astype(np.uint8))
res_inputB = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8))
res_inputB.save(os.path.join(out_path, inputB_name))
elif cfg.model_net == "StarGAN": elif cfg.model_net == "StarGAN":
for data in A_test_reader(): for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][ real_img, label_org, label_trg, image_name = data[0][
...@@ -199,9 +202,11 @@ def save_test_image(epoch, ...@@ -199,9 +202,11 @@ def save_test_image(epoch,
images_concat = np.concatenate(images_concat, 1) images_concat = np.concatenate(images_concat, 1)
image_name_save = "fake_img" + str(epoch) + "_" + str( image_name_save = "fake_img" + str(epoch) + "_" + str(
np.array(image_name)[0].astype('int32')) + '.jpg' np.array(image_name)[0].astype('int32')) + '.jpg'
imageio.imwrite(
os.path.join(out_path, image_name_save), ( res = Image.fromarray(((images_concat + 1) * 127.5).astype(
(images_concat + 1) * 127.5).astype(np.uint8)) np.uint8))
res.save(os.path.join(out_path, image_name_save))
elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN': elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
for data in A_test_reader(): for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][ real_img, label_org, label_trg, image_name = data[0][
...@@ -247,9 +252,10 @@ def save_test_image(epoch, ...@@ -247,9 +252,10 @@ def save_test_image(epoch,
images_concat = np.concatenate(images_concat, 1) images_concat = np.concatenate(images_concat, 1)
image_name_save = 'fake_img_' + str(epoch) + '_' + str( image_name_save = 'fake_img_' + str(epoch) + '_' + str(
np.array(image_name)[0].astype('int32')) + '.jpg' np.array(image_name)[0].astype('int32')) + '.jpg'
image_path = os.path.join(out_path, image_name_save)
imageio.imwrite(image_path, ( res = Image.fromarray(((images_concat + 1) * 127.5).astype(
(images_concat + 1) * 127.5).astype(np.uint8)) np.uint8))
res.save(os.path.join(out_path, image_name_save))
else: else:
for data_A, data_B in zip(A_test_reader(), B_test_reader()): for data_A, data_B in zip(A_test_reader(), B_test_reader()):
...@@ -282,24 +288,30 @@ def save_test_image(epoch, ...@@ -282,24 +288,30 @@ def save_test_image(epoch,
A_name).astype('int32')[0]] A_name).astype('int32')[0]]
cycB_name = "cycB_" + str(epoch) + "_" + B_id2name[np.array( cycB_name = "cycB_" + str(epoch) + "_" + B_id2name[np.array(
B_name).astype('int32')[0]] B_name).astype('int32')[0]]
imageio.imwrite(
os.path.join(out_path, fakeB_name), ( res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
(fake_B_temp + 1) * 127.5).astype(np.uint8)) np.uint8))
imageio.imwrite( res_fakeB.save(os.path.join(out_path, fakeB_name))
os.path.join(out_path, fakeA_name), (
(fake_A_temp + 1) * 127.5).astype(np.uint8)) res_fakeA = Image.fromarray(((fake_A_temp + 1) * 127.5).astype(
imageio.imwrite( np.uint8))
os.path.join(out_path, cycA_name), ( res_fakeA.save(os.path.join(out_path, fakeA_name))
(cyc_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite( res_cycA = Image.fromarray(((cyc_A_temp + 1) * 127.5).astype(
os.path.join(out_path, cycB_name), ( np.uint8))
(cyc_B_temp + 1) * 127.5).astype(np.uint8)) res_cycA.save(os.path.join(out_path, cycA_name))
imageio.imwrite(
os.path.join(out_path, inputA_name), ( res_cycB = Image.fromarray(((cyc_B_temp + 1) * 127.5).astype(
(input_A_temp + 1) * 127.5).astype(np.uint8)) np.uint8))
imageio.imwrite( res_cycB.save(os.path.join(out_path, cycB_name))
os.path.join(out_path, inputB_name), (
(input_B_temp + 1) * 127.5).astype(np.uint8)) res_inputA = Image.fromarray(((input_A_temp + 1) * 127.5).astype(
np.uint8))
res_inputA.save(os.path.join(out_path, inputA_name))
res_inputB = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8))
res_inputB.save(os.path.join(out_path, inputB_name))
class ImagePool(object): class ImagePool(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册