From d041e045ee6f820f48ceca18acbc430b3b298dc9 Mon Sep 17 00:00:00 2001 From: ShusenTang Date: Wed, 16 Oct 2019 15:14:27 +0900 Subject: [PATCH] fix bug: issue #29 #30 --- code/chapter05_CNN/5.5_lenet.ipynb | 10 ++++++---- code/d2lzh_pytorch/utils.py | 2 +- docs/chapter05_CNN/5.5_lenet.md | 6 ++++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/code/chapter05_CNN/5.5_lenet.ipynb b/code/chapter05_CNN/5.5_lenet.ipynb index 6b837de..bc79777 100644 --- a/code/chapter05_CNN/5.5_lenet.ipynb +++ b/code/chapter05_CNN/5.5_lenet.ipynb @@ -159,8 +159,10 @@ "outputs": [], "source": [ "# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进:它的完整实现将在“图像增广”一节中描述\n", - "def evaluate_accuracy(data_iter, net):\n", - " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "def evaluate_accuracy(data_iter, net, device=None):\n", + " if device is None and isinstance(net, torch.nn.Module):\n", + " # 如果没指定device就使用net的device\n", + " device = list(net.parameters())[0].device\n", " acc_sum, n = 0.0, 0\n", " with torch.no_grad():\n", " for X, y in data_iter:\n", @@ -253,7 +255,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [default]", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -267,7 +269,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.7.3" }, "varInspector": { "cols": { diff --git a/code/d2lzh_pytorch/utils.py b/code/d2lzh_pytorch/utils.py index 7bcfdc6..0171e42 100644 --- a/code/d2lzh_pytorch/utils.py +++ b/code/d2lzh_pytorch/utils.py @@ -205,7 +205,7 @@ def corr2d(X, K): # ############################ 5.5 ######################### def evaluate_accuracy(data_iter, net, device=None): - if device is None: + if device is None and isinstance(net, torch.nn.Module): # 如果没指定device就使用net的device device = list(net.parameters())[0].device acc_sum, n = 0.0, 0 diff --git a/docs/chapter05_CNN/5.5_lenet.md b/docs/chapter05_CNN/5.5_lenet.md index b1a1cd2..1341ae0 100644 --- a/docs/chapter05_CNN/5.5_lenet.md +++ b/docs/chapter05_CNN/5.5_lenet.md @@ -102,8 +102,10 @@ train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size) ``` python # 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进。 -def evaluate_accuracy(data_iter, net, - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): +def evaluate_accuracy(data_iter, net, device=None): + if device is None and isinstance(net, torch.nn.Module): + # 如果没指定device就使用net的device + device = list(net.parameters())[0].device acc_sum, n = 0.0, 0 with torch.no_grad(): for X, y in data_iter: -- GitLab