From f715190b4d97ebcc7e37899a611632aa3f6df87e Mon Sep 17 00:00:00 2001 From: 64853c78922a365f6dc3d773 <64853c78922a365f6dc3d773@devide> Date: Sun, 11 Jun 2023 08:04:33 +0000 Subject: [PATCH] Auto Commit --- main.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/main.py b/main.py index 8cdfdc3..9ee15f5 100644 --- a/main.py +++ b/main.py @@ -325,16 +325,12 @@ train_loss , val_loss = train(train_iter ,test_iter, num_epochs , updater , loss # 保存 times = time.localtime() path = "%04d-%02d-%02d_%02d_%02d_%02d" % (times.tm_year, times.tm_mon, times.tm_mday, times.tm_hour, times.tm_min, times.tm_sec) -torch.save(model.state_dict(), 'model\\' + path +'.pth') +torch.save(model.state_dict(), 'model/' + path +'.pth') # 加载 # state_dict = torch.load(r'model\GRU\2023-06-10_10_55_23.pth') # model.load_state_dict(state_dict) # model1.eval() - - - - pred(test_x , test_y) plot_1() d2l.show_heatmaps(model.attention.attention_weights.mean(axis = 0).cpu().reshape((1, 1, 150,150)),xlabel='Keys', ylabel='Queries') -- GitLab