Auto Commit

上级 97397d95
......@@ -325,16 +325,12 @@ train_loss , val_loss = train(train_iter ,test_iter, num_epochs , updater , loss
# 保存
times = time.localtime()
path = "%04d-%02d-%02d_%02d_%02d_%02d" % (times.tm_year, times.tm_mon, times.tm_mday, times.tm_hour, times.tm_min, times.tm_sec)
torch.save(model.state_dict(), 'model\\' + path +'.pth')
torch.save(model.state_dict(), 'model/' + path +'.pth')
# 加载
# state_dict = torch.load(r'model\GRU\2023-06-10_10_55_23.pth')
# model.load_state_dict(state_dict)
# model1.eval()
pred(test_x , test_y)
plot_1()
d2l.show_heatmaps(model.attention.attention_weights.mean(axis = 0).cpu().reshape((1, 1, 150,150)),xlabel='Keys', ylabel='Queries')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册