From c2a79feb05f385ff8b98f8274fbb2d9fc7f2363c Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Fri, 4 Nov 2022 10:26:40 +0800 Subject: [PATCH] fix 0 tensor in readme (#721) --- ...\254\344\272\214\345\244\251\344\275\234\344\270\232.py" | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git "a/education/\347\254\254\344\272\214\345\244\251\344\275\234\344\270\232.py" "b/education/\347\254\254\344\272\214\345\244\251\344\275\234\344\270\232.py" index c2223c0..d995632 100644 --- "a/education/\347\254\254\344\272\214\345\244\251\344\275\234\344\270\232.py" +++ "b/education/\347\254\254\344\272\214\345\244\251\344\275\234\344\270\232.py" @@ -159,7 +159,7 @@ for pass_id in range(100): optimizerD.clear_grad() errD = errD_real + errD_fake - losses[0].append(errD.numpy()[0]) + losses[0].append(float(errD)) ############################ # (2) Update G network: maximize log(D(G(z))) @@ -174,7 +174,7 @@ for pass_id in range(100): optimizerG.step() optimizerG.clear_grad() - losses[1].append(errG.numpy()[0]) + losses[1].append(float(errG)) ############################ @@ -196,7 +196,7 @@ for pass_id in range(100): plt.xticks([]) plt.yticks([]) plt.subplots_adjust(wspace=0.1, hspace=0.1) - msg = 'Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}'.format(pass_id, batch_id, errD.numpy()[0], errG.numpy()[0]) + msg = 'Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}'.format(pass_id, batch_id, float(errD), float(errG)) print(msg) plt.suptitle(msg,fontsize=20) plt.draw() -- GitLab