提交 f621144c 编写于 作者: X xiaoting 提交者: lvmengsi

fix fetch_list type of gan (#795)

上级 fbdd8fdb
......@@ -399,7 +399,7 @@ prediction, [avg_loss, acc] = train_program()
# 输入的原始图像数据,名称为img,大小为28*28*1
# 标签层,名称为label,对应输入图片的类别标签
# 告知网络传入的数据分为两部分,第一部分是img值,第二部分是label值
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
feeder = fluid.DataFeeder(feed_list=['img', 'label'], place=place)
# 选择Adam优化器
optimizer = optimizer_program()
......
......@@ -441,7 +441,7 @@ prediction, [avg_loss, acc] = train_program()
# 输入的原始图像数据,名称为img,大小为28*28*1
# 标签层,名称为label,对应输入图片的类别标签
# 告知网络传入的数据分为两部分,第一部分是img值,第二部分是label值
feeder = fluid.DataFeeder(feed_list=[‘img’, ‘label’], place=place)
feeder = fluid.DataFeeder(feed_list=['img', 'label'], place=place)
# 选择Adam优化器
optimizer = optimizer_program()
......
......@@ -368,7 +368,7 @@ for pass_id in range(epoch):
# 虚假图片
generated_image = exe.run(g_program,
feed={'noise': noise_data},
fetch_list={g_img})[0]
fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image])
......@@ -378,7 +378,7 @@ for pass_id in range(epoch):
'img': generated_image,
'label': fake_labels,
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
# D 判断真实图片为真的loss
d_loss_2 = exe.run(d_program,
......@@ -386,7 +386,7 @@ for pass_id in range(epoch):
'img': real_image,
'label': real_labels,
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n)
......@@ -398,7 +398,7 @@ for pass_id in range(epoch):
size=[batch_size, NOISE_SIZE]).astype('float32')
dg_loss_n = exe.run(dg_program,
feed={'noise': noise_data},
fetch_list={dg_loss})[0][0]
fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n)
t_time += (time.time() - s_time)
if batch_id % 10 == 0 :
......@@ -407,7 +407,7 @@ for pass_id in range(epoch):
# 每轮的生成结果
generated_images = exe.run(g_program_test,
feed={'noise': const_n},
fetch_list={g_img})[0]
fetch_list=[g_img])[0]
# 将真实图片和生成图片连接
total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images)
......
......@@ -410,7 +410,7 @@ for pass_id in range(epoch):
# 虚假图片
generated_image = exe.run(g_program,
feed={'noise': noise_data},
fetch_list={g_img})[0]
fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image])
......@@ -420,7 +420,7 @@ for pass_id in range(epoch):
'img': generated_image,
'label': fake_labels,
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
# D 判断真实图片为真的loss
d_loss_2 = exe.run(d_program,
......@@ -428,7 +428,7 @@ for pass_id in range(epoch):
'img': real_image,
'label': real_labels,
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n)
......@@ -440,7 +440,7 @@ for pass_id in range(epoch):
size=[batch_size, NOISE_SIZE]).astype('float32')
dg_loss_n = exe.run(dg_program,
feed={'noise': noise_data},
fetch_list={dg_loss})[0][0]
fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n)
t_time += (time.time() - s_time)
if batch_id % 10 == 0 :
......@@ -449,7 +449,7 @@ for pass_id in range(epoch):
# 每轮的生成结果
generated_images = exe.run(g_program_test,
feed={'noise': const_n},
fetch_list={g_img})[0]
fetch_list=[g_img])[0]
# 将真实图片和生成图片连接
total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册