未验证 提交 d42b46ca 编写于 作者: W wangna11BD 提交者: GitHub

fix zero tensor (#719)

* fix 0 tensor
上级 5f9c92e3
......@@ -182,7 +182,7 @@ for pass_id in range(100):
optimizerD.clear_grad()
errD = errD_real + errD_fake
losses[0].append(errD.numpy()[0])
losses[0].append(float(errD))
############################
# (2) Update G network: maximize log(D(G(z)))
......@@ -197,7 +197,7 @@ for pass_id in range(100):
optimizerG.step()
optimizerG.clear_grad()
losses[1].append(errG.numpy()[0])
losses[1].append(float(errG))
############################
......@@ -219,7 +219,7 @@ for pass_id in range(100):
plt.xticks([])
plt.yticks([])
plt.subplots_adjust(wspace=0.1, hspace=0.1)
msg = 'Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}'.format(pass_id, batch_id, errD.numpy()[0], errG.numpy()[0])
msg = 'Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}'.format(pass_id, batch_id, float(errD), float(errG))
print(msg)
plt.suptitle(msg,fontsize=20)
plt.draw()
......
......@@ -144,8 +144,8 @@ class StyleGANv2FittingPredictor(StyleGANv2Predictor):
optimizer.step()
pbar.set_description(
(f"perceptual: {p_loss.numpy()[0]:.4f}; "
f"mse: {mse_loss.numpy()[0]:.4f}; lr: {lr:.4f}"))
(f"perceptual: {float(p_loss):.4f}; "
f"mse: {float(mse_loss):.4f}; lr: {lr:.4f}"))
img_gen, _ = generator([latent_n],
input_is_latent=True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册