未验证 提交 55c5db2f 编写于 作者: C chengduo 提交者: GitHub

fix fetch_list type (#786)

上级 a8f1b403
......@@ -124,7 +124,7 @@ def train(args):
total_label = np.concatenate([real_labels, fake_labels])
s_time = time.time()
generated_image = exe.run(
g_program, feed={'noise': noise_data}, fetch_list={g_img})[0]
g_program, feed={'noise': noise_data}, fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image])
......@@ -134,7 +134,7 @@ def train(args):
'img': generated_image,
'label': fake_labels,
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
d_loss_2 = exe.run(
d_program,
......@@ -142,7 +142,7 @@ def train(args):
'img': real_image,
'label': real_labels,
},
fetch_list={d_loss})[0][0]
fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n)
......@@ -153,7 +153,7 @@ def train(args):
dg_loss_n = exe.run(
dg_program,
feed={'noise': noise_data},
fetch_list={dg_loss})[0][0]
fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n)
t_time += (time.time() - s_time)
if batch_id % 10 == 0:
......@@ -162,7 +162,7 @@ def train(args):
# generate image each batch
generated_images = exe.run(
g_program_test, feed={'noise': const_n},
fetch_list={g_img})[0]
fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images)
msg = "Epoch ID={0} Batch ID={1} D-Loss={2} DG-Loss={3}\n gen={4}".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册