提交 be2b6497 编写于 作者: S ShusenTang

fix bug in d2l evaluate_accuracy

上级 00a2537e
......@@ -204,8 +204,10 @@ def corr2d(X, K):
# ############################ 5.5 #########################
def evaluate_accuracy(data_iter, net,
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
def evaluate_accuracy(data_iter, net, device=None):
if device is None:
# 如果没指定device就使用net的device
device = list(net.parameters())[0].device
acc_sum, n = 0.0, 0
with torch.no_grad():
for X, y in data_iter:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册