diff --git a/main.py b/main.py index 9780a2302de1f7ed474e7d4f8e5ebc00d2848a65..8cdfdc3428666bde8c9d91505b06d865d1575929 100644 --- a/main.py +++ b/main.py @@ -334,21 +334,9 @@ torch.save(model.state_dict(), 'model\\' + path +'.pth') -# # 预测结果 + pred(test_x , test_y) -# -# ## plot plot_1() -# -# -# - -plt.show(block=True) -from d2l import torch as d2l -# for i in [20,30,50,90,140,320,500]: -# d2l.show_heatmaps(model.attention.attention_weights[i].cpu().reshape((1, 1, 150,150)),xlabel='Keys', ylabel='Queries') - - d2l.show_heatmaps(model.attention.attention_weights.mean(axis = 0).cpu().reshape((1, 1, 150,150)),xlabel='Keys', ylabel='Queries')