diff --git "a/education/\347\254\254\344\272\214\345\244\251\344\275\234\344\270\232.py" "b/education/\347\254\254\344\272\214\345\244\251\344\275\234\344\270\232.py" index c2223c0e04146886e16d7a99b44ddd0e5ebdeb35..d995632a5e760b9a428afe999792101fb8e891a2 100644 --- "a/education/\347\254\254\344\272\214\345\244\251\344\275\234\344\270\232.py" +++ "b/education/\347\254\254\344\272\214\345\244\251\344\275\234\344\270\232.py" @@ -159,7 +159,7 @@ for pass_id in range(100): optimizerD.clear_grad() errD = errD_real + errD_fake - losses[0].append(errD.numpy()[0]) + losses[0].append(float(errD)) ############################ # (2) Update G network: maximize log(D(G(z))) @@ -174,7 +174,7 @@ for pass_id in range(100): optimizerG.step() optimizerG.clear_grad() - losses[1].append(errG.numpy()[0]) + losses[1].append(float(errG)) ############################ @@ -196,7 +196,7 @@ for pass_id in range(100): plt.xticks([]) plt.yticks([]) plt.subplots_adjust(wspace=0.1, hspace=0.1) - msg = 'Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}'.format(pass_id, batch_id, errD.numpy()[0], errG.numpy()[0]) + msg = 'Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}'.format(pass_id, batch_id, float(errD), float(errG)) print(msg) plt.suptitle(msg,fontsize=20) plt.draw()