提交 d041e045 编写于 作者: S ShusenTang

fix bug: issue #29 #30

上级 be2b6497
......@@ -159,8 +159,10 @@
"outputs": [],
"source": [
"# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进:它的完整实现将在“图像增广”一节中描述\n",
"def evaluate_accuracy(data_iter, net):\n",
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"def evaluate_accuracy(data_iter, net, device=None):\n",
" if device is None and isinstance(net, torch.nn.Module):\n",
" # 如果没指定device就使用net的device\n",
" device = list(net.parameters())[0].device\n",
" acc_sum, n = 0.0, 0\n",
" with torch.no_grad():\n",
" for X, y in data_iter:\n",
......@@ -253,7 +255,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
......@@ -267,7 +269,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.7.3"
},
"varInspector": {
"cols": {
......
......@@ -205,7 +205,7 @@ def corr2d(X, K):
# ############################ 5.5 #########################
def evaluate_accuracy(data_iter, net, device=None):
if device is None:
if device is None and isinstance(net, torch.nn.Module):
# 如果没指定device就使用net的device
device = list(net.parameters())[0].device
acc_sum, n = 0.0, 0
......
......@@ -102,8 +102,10 @@ train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
``` python
# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进。
def evaluate_accuracy(data_iter, net,
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
def evaluate_accuracy(data_iter, net, device=None):
if device is None and isinstance(net, torch.nn.Module):
# 如果没指定device就使用net的device
device = list(net.parameters())[0].device
acc_sum, n = 0.0, 0
with torch.no_grad():
for X, y in data_iter:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册