From 03462fc9a577e2dedd042961963eef31993068d1 Mon Sep 17 00:00:00 2001 From: ShusenTang Date: Sun, 23 Feb 2020 12:32:51 +0800 Subject: [PATCH] :hammer: fix bug #97 --- .../3.16_kaggle-house-price.ipynb | 747 +++++++++--------- .../3.16_kaggle-house-price.md | 16 +- 2 files changed, 371 insertions(+), 392 deletions(-) diff --git a/code/chapter03_DL-basics/3.16_kaggle-house-price.ipynb b/code/chapter03_DL-basics/3.16_kaggle-house-price.ipynb index bebe879..2c5c66a 100644 --- a/code/chapter03_DL-basics/3.16_kaggle-house-price.ipynb +++ b/code/chapter03_DL-basics/3.16_kaggle-house-price.ipynb @@ -298,7 +298,7 @@ " with torch.no_grad():\n", " # 将小于1的值设成1,使得取对数时数值更稳定\n", " clipped_preds = torch.max(net(features), torch.tensor(1.0))\n", - " rmse = torch.sqrt(2 * loss(clipped_preds.log(), labels.log()).mean())\n", + " rmse = torch.sqrt(loss(clipped_preds.log(), labels.log()))\n", " return rmse.item()" ] }, @@ -405,12 +405,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "fold 0, train rmse 0.240939, valid rmse 0.221437\n", - "fold 1, train rmse 0.229326, valid rmse 0.267492\n", - "fold 2, train rmse 0.231815, valid rmse 0.237722\n", - "fold 3, train rmse 0.237550, valid rmse 0.219035\n", - "fold 4, train rmse 0.230578, valid rmse 0.258887\n", - "5-fold validation: avg train rmse 0.234042, avg valid rmse 0.240915\n" + "fold 0, train rmse 0.170585, valid rmse 0.156860\n", + "fold 1, train rmse 0.162552, valid rmse 0.190944\n", + "fold 2, train rmse 0.164199, valid rmse 0.168767\n", + "fold 3, train rmse 0.168698, valid rmse 0.154873\n", + "fold 4, train rmse 0.163213, valid rmse 0.183080\n", + "5-fold validation: avg train rmse 0.165849, avg valid rmse 0.170905\n" ] }, { @@ -450,10 +450,10 @@ " \n", " \n", + "\" id=\"m9cbda39ac0\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -489,7 +489,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -529,7 +529,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -562,7 +562,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -608,7 +608,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -663,7 +663,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -852,15 +852,15 @@ " \n", " \n", + "\" id=\"m0df7c1c40c\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -872,94 +872,80 @@ " \n", " \n", + "\" id=\"m8231d37304\" style=\"stroke:#000000;stroke-width:0.6;\"/>\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1021,210 +1007,210 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", @@ -1261,12 +1247,12 @@ "z\n", "\" style=\"fill:#ffffff;opacity:0.8;stroke:#cccccc;stroke-linejoin:miter;\"/>\n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1361,12 +1347,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1424,14 +1410,14 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -1469,7 +1455,7 @@ " preds = net(test_features).detach().numpy()\n", " test_data['SalePrice'] = pd.Series(preds.reshape(1, -1)[0])\n", " submission = pd.concat([test_data['Id'], test_data['SalePrice']], axis=1)\n", - " submission.to_csv('./submission.csv', index=False)" + " # submission.to_csv('./submission.csv', index=False)" ] }, { @@ -1481,7 +1467,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "train rmse 0.230200\n" + "train rmse 0.162085\n" ] }, { @@ -1521,10 +1507,10 @@ " \n", " \n", + "\" id=\"me383947859\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1560,7 +1546,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1600,7 +1586,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1633,7 +1619,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1679,7 +1665,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1734,7 +1720,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1923,15 +1909,15 @@ " \n", " \n", + "\" id=\"mf4b47cc8b8\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1943,87 +1929,80 @@ " \n", " \n", + "\" id=\"m5bb3ee9e0a\" style=\"stroke:#000000;stroke-width:0.6;\"/>\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2085,106 +2064,106 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", @@ -2211,14 +2190,14 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ - "" + "" ] }, "metadata": {}, diff --git a/docs/chapter03_DL-basics/3.16_kaggle-house-price.md b/docs/chapter03_DL-basics/3.16_kaggle-house-price.md index bdb8715..751d6d7 100644 --- a/docs/chapter03_DL-basics/3.16_kaggle-house-price.md +++ b/docs/chapter03_DL-basics/3.16_kaggle-house-price.md @@ -131,7 +131,7 @@ def log_rmse(net, features, labels): with torch.no_grad(): # 将小于1的值设成1,使得取对数时数值更稳定 clipped_preds = torch.max(net(features), torch.tensor(1.0)) - rmse = torch.sqrt(2 * loss(clipped_preds.log(), labels.log()).mean()) + rmse = torch.sqrt(loss(clipped_preds.log(), labels.log())) return rmse.item() ``` @@ -203,12 +203,12 @@ def k_fold(k, X_train, y_train, num_epochs, ``` 输出: ``` -fold 0, train rmse 0.241054, valid rmse 0.221462 -fold 1, train rmse 0.229857, valid rmse 0.268489 -fold 2, train rmse 0.231413, valid rmse 0.238157 -fold 3, train rmse 0.237733, valid rmse 0.218747 -fold 4, train rmse 0.230720, valid rmse 0.258712 -5-fold validation: avg train rmse 0.234155, avg valid rmse 0.241113 +fold 0, train rmse 0.170585, valid rmse 0.156860 +fold 1, train rmse 0.162552, valid rmse 0.190944 +fold 2, train rmse 0.164199, valid rmse 0.168767 +fold 3, train rmse 0.168698, valid rmse 0.154873 +fold 4, train rmse 0.163213, valid rmse 0.183080 +5-fold validation: avg train rmse 0.165849, avg valid rmse 0.170905 ``` @@ -250,7 +250,7 @@ train_and_pred(train_features, test_features, train_labels, test_data, num_epoch ``` 输出: ``` -train rmse 0.229943 +train rmse 0.162085 ``` -- GitLab