diff --git a/code/chapter05_CNN/5.5_lenet.ipynb b/code/chapter05_CNN/5.5_lenet.ipynb index 6b837decfc141c5be4455230c0755e77be8bc45d..bc797770bbe88f89a887524230fa9aa342774549 100644 --- a/code/chapter05_CNN/5.5_lenet.ipynb +++ b/code/chapter05_CNN/5.5_lenet.ipynb @@ -159,8 +159,10 @@ "outputs": [], "source": [ "# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进:它的完整实现将在“图像增广”一节中描述\n", - "def evaluate_accuracy(data_iter, net):\n", - " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "def evaluate_accuracy(data_iter, net, device=None):\n", + " if device is None and isinstance(net, torch.nn.Module):\n", + " # 如果没指定device就使用net的device\n", + " device = list(net.parameters())[0].device\n", " acc_sum, n = 0.0, 0\n", " with torch.no_grad():\n", " for X, y in data_iter:\n", @@ -253,7 +255,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [default]", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -267,7 +269,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.7.3" }, "varInspector": { "cols": { diff --git a/code/d2lzh_pytorch/utils.py b/code/d2lzh_pytorch/utils.py index 7bcfdc6aa8f0859d29023fed1b37b8d924741b3a..0171e42feaf695b3ec0171f55f6a83a1099cc6ba 100644 --- a/code/d2lzh_pytorch/utils.py +++ b/code/d2lzh_pytorch/utils.py @@ -205,7 +205,7 @@ def corr2d(X, K): # ############################ 5.5 ######################### def evaluate_accuracy(data_iter, net, device=None): - if device is None: + if device is None and isinstance(net, torch.nn.Module): # 如果没指定device就使用net的device device = list(net.parameters())[0].device acc_sum, n = 0.0, 0 diff --git a/docs/chapter05_CNN/5.5_lenet.md b/docs/chapter05_CNN/5.5_lenet.md index b1a1cd247d6911860d4ef56f1360210b2e64ccf1..1341ae046e2787fe768d7bcdf4f241e35ac7d200 100644 --- a/docs/chapter05_CNN/5.5_lenet.md +++ b/docs/chapter05_CNN/5.5_lenet.md @@ -102,8 +102,10 @@ train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size) ``` python # 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进。 -def evaluate_accuracy(data_iter, net, - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): +def evaluate_accuracy(data_iter, net, device=None): + if device is None and isinstance(net, torch.nn.Module): + # 如果没指定device就使用net的device + device = list(net.parameters())[0].device acc_sum, n = 0.0, 0 with torch.no_grad(): for X, y in data_iter: