Auto Commit

上级 ae04d81c
......@@ -334,21 +334,9 @@ torch.save(model.state_dict(), 'model\\' + path +'.pth')
# # 预测结果
pred(test_x , test_y)
#
# ## plot
plot_1()
#
#
#
plt.show(block=True)
from d2l import torch as d2l
# for i in [20,30,50,90,140,320,500]:
# d2l.show_heatmaps(model.attention.attention_weights[i].cpu().reshape((1, 1, 150,150)),xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(model.attention.attention_weights.mean(axis = 0).cpu().reshape((1, 1, 150,150)),xlabel='Keys', ylabel='Queries')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册