提交 18994ce6 编写于 作者: L LiYanlin

Fixes #1

上级 6e84935e
......@@ -399,7 +399,7 @@
" test_mask_logits = logits[mask]\n",
" predict_y = test_mask_logits.max(1)[1]\n",
" accuarcy = torch.eq(predict_y, tensor_y[mask]).float().mean()\n",
" return accuarcy, test_mask_logits.numpy(), tensor_y[mask].numpy()\n"
" return accuarcy, test_mask_logits.cpu().numpy(), tensor_y[mask].cpu().numpy()\n"
]
},
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册