"labels = data[\"label\"] # The subscript of data picture is the standard for us to judge whether it is correct or not\n",
"labels = data[\"label\"]\n",
"\n",
"\n",
"output =model.predict(Tensor(data['image']))\n",
"output =model.predict(Tensor(data['image']))\n",
"# The predict function returns the probability of 0-9 numbers corresponding to each picture\n",
"prb = output.asnumpy()\n",
"prb = output.asnumpy()\n",
"pred = np.argmax(output.asnumpy(), axis=1)\n",
"pred = np.argmax(output.asnumpy(), axis=1)\n",
"err_num = []\n",
"err_num = []\n",
...
@@ -828,12 +1034,11 @@
...
@@ -828,12 +1034,11 @@
" plt.axis(\"off\")\n",
" plt.axis(\"off\")\n",
" if color == 'red':\n",
" if color == 'red':\n",
" index = 0\n",
" index = 0\n",
" # Print out the wrong data identified by the current group\n",
" print(\"Row {}, column {} is incorrectly identified as {}, the correct value should be {}\".format(int(i/8)+1, i%8+1, pred[i], labels[i]), '\\n')\n",
" print(\"Row {}, column {} is incorrectly identified as {}, the correct value should be {}\".format(int(i/8)+1, i%8+1, pred[i], labels[i]), '\\n')\n",
"if index:\n",
"if index:\n",
" print(\"All the figures in this group are predicted correctly!\")\n",
" print(\"All the figures in this group are predicted correctly!\")\n",
"print(pred, \"<--Predicted figures\") # Print the numbers recognized by each group of pictures\n",
"print(pred, \"<--Predicted figures\") \n",
"print(labels, \"<--The right number\") # Print the subscript corresponding to each group of pictures\n",
"print(labels, \"<--The right number\")\n",
"plt.show()"
"plt.show()"
]
]
},
},
...
@@ -843,29 +1048,68 @@
...
@@ -843,29 +1048,68 @@
"source": [
"source": [
"构建一个概率分析的饼图函数。\n",
"构建一个概率分析的饼图函数。\n",
"\n",
"\n",
"备注:prb为上一段代码中,存储这组数对应的数字概率。"
"备注:`prb`为上一段代码中,存储这组数对应的数字概率。"
]
]
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Figure 1 probability of corresponding numbers [0-9]:\n",