diff --git a/main.py b/main.py index 8cdfdc3428666bde8c9d91505b06d865d1575929..9ee15f5cc77b993b2bc53cd928a31fa1c3cb3eef 100644 --- a/main.py +++ b/main.py @@ -325,16 +325,12 @@ train_loss , val_loss = train(train_iter ,test_iter, num_epochs , updater , loss # 保存 times = time.localtime() path = "%04d-%02d-%02d_%02d_%02d_%02d" % (times.tm_year, times.tm_mon, times.tm_mday, times.tm_hour, times.tm_min, times.tm_sec) -torch.save(model.state_dict(), 'model\\' + path +'.pth') +torch.save(model.state_dict(), 'model/' + path +'.pth') # 加载 # state_dict = torch.load(r'model\GRU\2023-06-10_10_55_23.pth') # model.load_state_dict(state_dict) # model1.eval() - - - - pred(test_x , test_y) plot_1() d2l.show_heatmaps(model.attention.attention_weights.mean(axis = 0).cpu().reshape((1, 1, 150,150)),xlabel='Keys', ylabel='Queries')