提交 f621144c 编写于 作者: X xiaoting 提交者: lvmengsi

fix fetch_list type of gan (#795)

上级 fbdd8fdb
...@@ -399,7 +399,7 @@ prediction, [avg_loss, acc] = train_program() ...@@ -399,7 +399,7 @@ prediction, [avg_loss, acc] = train_program()
# 输入的原始图像数据,名称为img,大小为28*28*1 # 输入的原始图像数据,名称为img,大小为28*28*1
# 标签层,名称为label,对应输入图片的类别标签 # 标签层,名称为label,对应输入图片的类别标签
# 告知网络传入的数据分为两部分,第一部分是img值,第二部分是label值 # 告知网络传入的数据分为两部分,第一部分是img值,第二部分是label值
feeder = fluid.DataFeeder(feed_list=[img, label], place=place) feeder = fluid.DataFeeder(feed_list=['img', 'label'], place=place)
# 选择Adam优化器 # 选择Adam优化器
optimizer = optimizer_program() optimizer = optimizer_program()
......
...@@ -441,7 +441,7 @@ prediction, [avg_loss, acc] = train_program() ...@@ -441,7 +441,7 @@ prediction, [avg_loss, acc] = train_program()
# 输入的原始图像数据,名称为img,大小为28*28*1 # 输入的原始图像数据,名称为img,大小为28*28*1
# 标签层,名称为label,对应输入图片的类别标签 # 标签层,名称为label,对应输入图片的类别标签
# 告知网络传入的数据分为两部分,第一部分是img值,第二部分是label值 # 告知网络传入的数据分为两部分,第一部分是img值,第二部分是label值
feeder = fluid.DataFeeder(feed_list=[‘img’, ‘label’], place=place) feeder = fluid.DataFeeder(feed_list=['img', 'label'], place=place)
# 选择Adam优化器 # 选择Adam优化器
optimizer = optimizer_program() optimizer = optimizer_program()
......
...@@ -368,7 +368,7 @@ for pass_id in range(epoch): ...@@ -368,7 +368,7 @@ for pass_id in range(epoch):
# 虚假图片 # 虚假图片
generated_image = exe.run(g_program, generated_image = exe.run(g_program,
feed={'noise': noise_data}, feed={'noise': noise_data},
fetch_list={g_img})[0] fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image]) total_images = np.concatenate([real_image, generated_image])
...@@ -378,7 +378,7 @@ for pass_id in range(epoch): ...@@ -378,7 +378,7 @@ for pass_id in range(epoch):
'img': generated_image, 'img': generated_image,
'label': fake_labels, 'label': fake_labels,
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
# D 判断真实图片为真的loss # D 判断真实图片为真的loss
d_loss_2 = exe.run(d_program, d_loss_2 = exe.run(d_program,
...@@ -386,7 +386,7 @@ for pass_id in range(epoch): ...@@ -386,7 +386,7 @@ for pass_id in range(epoch):
'img': real_image, 'img': real_image,
'label': real_labels, 'label': real_labels,
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2 d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n) losses[0].append(d_loss_n)
...@@ -398,7 +398,7 @@ for pass_id in range(epoch): ...@@ -398,7 +398,7 @@ for pass_id in range(epoch):
size=[batch_size, NOISE_SIZE]).astype('float32') size=[batch_size, NOISE_SIZE]).astype('float32')
dg_loss_n = exe.run(dg_program, dg_loss_n = exe.run(dg_program,
feed={'noise': noise_data}, feed={'noise': noise_data},
fetch_list={dg_loss})[0][0] fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n) losses[1].append(dg_loss_n)
t_time += (time.time() - s_time) t_time += (time.time() - s_time)
if batch_id % 10 == 0 : if batch_id % 10 == 0 :
...@@ -407,7 +407,7 @@ for pass_id in range(epoch): ...@@ -407,7 +407,7 @@ for pass_id in range(epoch):
# 每轮的生成结果 # 每轮的生成结果
generated_images = exe.run(g_program_test, generated_images = exe.run(g_program_test,
feed={'noise': const_n}, feed={'noise': const_n},
fetch_list={g_img})[0] fetch_list=[g_img])[0]
# 将真实图片和生成图片连接 # 将真实图片和生成图片连接
total_images = np.concatenate([real_image, generated_images]) total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images) fig = plot(total_images)
......
...@@ -410,7 +410,7 @@ for pass_id in range(epoch): ...@@ -410,7 +410,7 @@ for pass_id in range(epoch):
# 虚假图片 # 虚假图片
generated_image = exe.run(g_program, generated_image = exe.run(g_program,
feed={'noise': noise_data}, feed={'noise': noise_data},
fetch_list={g_img})[0] fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image]) total_images = np.concatenate([real_image, generated_image])
...@@ -420,7 +420,7 @@ for pass_id in range(epoch): ...@@ -420,7 +420,7 @@ for pass_id in range(epoch):
'img': generated_image, 'img': generated_image,
'label': fake_labels, 'label': fake_labels,
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
# D 判断真实图片为真的loss # D 判断真实图片为真的loss
d_loss_2 = exe.run(d_program, d_loss_2 = exe.run(d_program,
...@@ -428,7 +428,7 @@ for pass_id in range(epoch): ...@@ -428,7 +428,7 @@ for pass_id in range(epoch):
'img': real_image, 'img': real_image,
'label': real_labels, 'label': real_labels,
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2 d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n) losses[0].append(d_loss_n)
...@@ -440,7 +440,7 @@ for pass_id in range(epoch): ...@@ -440,7 +440,7 @@ for pass_id in range(epoch):
size=[batch_size, NOISE_SIZE]).astype('float32') size=[batch_size, NOISE_SIZE]).astype('float32')
dg_loss_n = exe.run(dg_program, dg_loss_n = exe.run(dg_program,
feed={'noise': noise_data}, feed={'noise': noise_data},
fetch_list={dg_loss})[0][0] fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n) losses[1].append(dg_loss_n)
t_time += (time.time() - s_time) t_time += (time.time() - s_time)
if batch_id % 10 == 0 : if batch_id % 10 == 0 :
...@@ -449,7 +449,7 @@ for pass_id in range(epoch): ...@@ -449,7 +449,7 @@ for pass_id in range(epoch):
# 每轮的生成结果 # 每轮的生成结果
generated_images = exe.run(g_program_test, generated_images = exe.run(g_program_test,
feed={'noise': const_n}, feed={'noise': const_n},
fetch_list={g_img})[0] fetch_list=[g_img])[0]
# 将真实图片和生成图片连接 # 将真实图片和生成图片连接
total_images = np.concatenate([real_image, generated_images]) total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images) fig = plot(total_images)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册