未验证 提交 55c5db2f 编写于 作者: C chengduo 提交者: GitHub

fix fetch_list type (#786)

上级 a8f1b403
...@@ -124,7 +124,7 @@ def train(args): ...@@ -124,7 +124,7 @@ def train(args):
total_label = np.concatenate([real_labels, fake_labels]) total_label = np.concatenate([real_labels, fake_labels])
s_time = time.time() s_time = time.time()
generated_image = exe.run( generated_image = exe.run(
g_program, feed={'noise': noise_data}, fetch_list={g_img})[0] g_program, feed={'noise': noise_data}, fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_image]) total_images = np.concatenate([real_image, generated_image])
...@@ -134,7 +134,7 @@ def train(args): ...@@ -134,7 +134,7 @@ def train(args):
'img': generated_image, 'img': generated_image,
'label': fake_labels, 'label': fake_labels,
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
d_loss_2 = exe.run( d_loss_2 = exe.run(
d_program, d_program,
...@@ -142,7 +142,7 @@ def train(args): ...@@ -142,7 +142,7 @@ def train(args):
'img': real_image, 'img': real_image,
'label': real_labels, 'label': real_labels,
}, },
fetch_list={d_loss})[0][0] fetch_list=[d_loss])[0][0]
d_loss_n = d_loss_1 + d_loss_2 d_loss_n = d_loss_1 + d_loss_2
losses[0].append(d_loss_n) losses[0].append(d_loss_n)
...@@ -153,7 +153,7 @@ def train(args): ...@@ -153,7 +153,7 @@ def train(args):
dg_loss_n = exe.run( dg_loss_n = exe.run(
dg_program, dg_program,
feed={'noise': noise_data}, feed={'noise': noise_data},
fetch_list={dg_loss})[0][0] fetch_list=[dg_loss])[0][0]
losses[1].append(dg_loss_n) losses[1].append(dg_loss_n)
t_time += (time.time() - s_time) t_time += (time.time() - s_time)
if batch_id % 10 == 0: if batch_id % 10 == 0:
...@@ -162,7 +162,7 @@ def train(args): ...@@ -162,7 +162,7 @@ def train(args):
# generate image each batch # generate image each batch
generated_images = exe.run( generated_images = exe.run(
g_program_test, feed={'noise': const_n}, g_program_test, feed={'noise': const_n},
fetch_list={g_img})[0] fetch_list=[g_img])[0]
total_images = np.concatenate([real_image, generated_images]) total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images) fig = plot(total_images)
msg = "Epoch ID={0} Batch ID={1} D-Loss={2} DG-Loss={3}\n gen={4}".format( msg = "Epoch ID={0} Batch ID={1} D-Loss={2} DG-Loss={3}\n gen={4}".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册