Unverified · commit 3fa53567 · authored by Chen Long, committed by GitHub

fix_test_loader (#889)

* fix_test_loader

* fix_mnist_wrongs
Parent 601a9386
@@ -19,7 +19,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -74,7 +74,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -112,12 +112,12 @@
"metadata": {},
"source": [
"## 组网\n",
"用paddle.nn下的API,如`Conv2d`、`Pool2D`、`Linead`完成LeNet的构建。"
"用paddle.nn下的API,如`Conv2d`、`MaxPool2d`、`Linear`完成LeNet的构建。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@@ -141,7 +141,7 @@
" x = F.relu(x)\n",
" x = self.conv2(x)\n",
" x = self.max_pool2(x)\n",
" x = paddle.reshape(x, shape=[-1, 16*5*5])\n",
" x = paddle.flatten(x, start_axis=1,stop_axis=-1)\n",
" x = self.linear1(x)\n",
" x = F.relu(x)\n",
" x = self.linear2(x)\n",
@@ -161,39 +161,39 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch_id: 0, loss is: [2.3029077], acc is: [0.15625]\n",
"epoch: 0, batch_id: 100, loss is: [1.6757016], acc is: [0.84375]\n",
"epoch: 0, batch_id: 200, loss is: [1.5340967], acc is: [0.96875]\n",
"epoch: 0, batch_id: 300, loss is: [1.4943825], acc is: [0.984375]\n",
"epoch: 0, batch_id: 400, loss is: [1.5084226], acc is: [1.]\n",
"epoch: 0, batch_id: 500, loss is: [1.5035012], acc is: [0.984375]\n",
"epoch: 0, batch_id: 600, loss is: [1.4784969], acc is: [0.984375]\n",
"epoch: 0, batch_id: 700, loss is: [1.5656701], acc is: [0.96875]\n",
"epoch: 0, batch_id: 800, loss is: [1.5226105], acc is: [1.]\n",
"epoch: 0, batch_id: 900, loss is: [1.5094678], acc is: [1.]\n",
"epoch: 1, batch_id: 0, loss is: [1.4956206], acc is: [0.984375]\n",
"epoch: 1, batch_id: 100, loss is: [1.4908005], acc is: [1.]\n",
"epoch: 1, batch_id: 200, loss is: [1.485649], acc is: [0.984375]\n",
"epoch: 1, batch_id: 300, loss is: [1.5090752], acc is: [1.]\n",
"epoch: 1, batch_id: 400, loss is: [1.5163708], acc is: [1.]\n",
"epoch: 1, batch_id: 500, loss is: [1.4863018], acc is: [0.984375]\n",
"epoch: 1, batch_id: 600, loss is: [1.4764814], acc is: [0.984375]\n",
"epoch: 1, batch_id: 700, loss is: [1.5496588], acc is: [0.984375]\n",
"epoch: 1, batch_id: 800, loss is: [1.4998187], acc is: [1.]\n",
"epoch: 1, batch_id: 900, loss is: [1.5110929], acc is: [1.]\n"
"epoch: 0, batch_id: 0, loss is: [2.3079572], acc is: [0.125]\n",
"epoch: 0, batch_id: 100, loss is: [1.7078608], acc is: [0.828125]\n",
"epoch: 0, batch_id: 200, loss is: [1.5642334], acc is: [0.90625]\n",
"epoch: 0, batch_id: 300, loss is: [1.7024238], acc is: [0.78125]\n",
"epoch: 0, batch_id: 400, loss is: [1.5536337], acc is: [0.921875]\n",
"epoch: 0, batch_id: 500, loss is: [1.6908336], acc is: [0.828125]\n",
"epoch: 0, batch_id: 600, loss is: [1.5622432], acc is: [0.921875]\n",
"epoch: 0, batch_id: 700, loss is: [1.5251796], acc is: [0.953125]\n",
"epoch: 0, batch_id: 800, loss is: [1.5698484], acc is: [0.890625]\n",
"epoch: 0, batch_id: 900, loss is: [1.5524453], acc is: [0.9375]\n",
"epoch: 1, batch_id: 0, loss is: [1.6443151], acc is: [0.84375]\n",
"epoch: 1, batch_id: 100, loss is: [1.5547533], acc is: [0.90625]\n",
"epoch: 1, batch_id: 200, loss is: [1.5019028], acc is: [1.]\n",
"epoch: 1, batch_id: 300, loss is: [1.4820204], acc is: [1.]\n",
"epoch: 1, batch_id: 400, loss is: [1.5215418], acc is: [0.984375]\n",
"epoch: 1, batch_id: 500, loss is: [1.4972374], acc is: [1.]\n",
"epoch: 1, batch_id: 600, loss is: [1.4930981], acc is: [0.984375]\n",
"epoch: 1, batch_id: 700, loss is: [1.4971689], acc is: [0.984375]\n",
"epoch: 1, batch_id: 800, loss is: [1.4611597], acc is: [1.]\n",
"epoch: 1, batch_id: 900, loss is: [1.4903957], acc is: [0.984375]\n"
]
}
],
"source": [
"import paddle\n",
"train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64, shuffle=True)\n",
"# 加载训练集 batch_size 设为 64\n",
"def train(model):\n",
" model.train()\n",
@@ -229,23 +229,21 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_id: 0, loss is: [1.4929559], acc is: [0.984375]\n",
"batch_id: 100, loss is: [1.4921299], acc is: [0.984375]\n",
"batch_id: 200, loss is: [1.5021144], acc is: [1.]\n",
"batch_id: 300, loss is: [1.4809179], acc is: [0.984375]\n",
"batch_id: 400, loss is: [1.4768506], acc is: [1.]\n",
"batch_id: 500, loss is: [1.4768407], acc is: [1.]\n",
"batch_id: 600, loss is: [1.476671], acc is: [0.984375]\n",
"batch_id: 700, loss is: [1.5093586], acc is: [1.]\n",
"batch_id: 800, loss is: [1.5057312], acc is: [1.]\n",
"batch_id: 900, loss is: [1.4923737], acc is: [1.]\n"
"batch_id: 0, loss is: [1.4767745], acc is: [1.]\n",
"batch_id: 20, loss is: [1.4841802], acc is: [0.984375]\n",
"batch_id: 40, loss is: [1.4997194], acc is: [1.]\n",
"batch_id: 60, loss is: [1.4895413], acc is: [1.]\n",
"batch_id: 80, loss is: [1.4668798], acc is: [1.]\n",
"batch_id: 100, loss is: [1.4611752], acc is: [1.]\n",
"batch_id: 120, loss is: [1.4613602], acc is: [1.]\n",
"batch_id: 140, loss is: [1.4923686], acc is: [1.]\n"
]
}
],
@@ -256,7 +254,7 @@
"def test(model):\n",
" model.eval()\n",
" batch_size = 64\n",
" for batch_id, data in enumerate(train_loader()):\n",
" for batch_id, data in enumerate(test_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
@@ -266,7 +264,7 @@
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" if batch_id % 20 == 0:\n",
" print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
"test(model)"
]
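The fix above presumes a `test_loader` built from the held-out split earlier in the notebook; a sketch mirroring the `train_loader` call, assuming `test_dataset` is `paddle.vision.datasets.MNIST(mode='test')`:

```python
# hypothetical: the actual cell defining test_loader is not shown in this diff
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)
```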
@@ -289,7 +287,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@@ -317,7 +315,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 22,
"metadata": {},
"outputs": [
{
@@ -325,202 +323,28 @@
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"step 10/938 - loss: 1.5564 - acc_top1: 0.7773 - acc_top2: 0.8103 - 17ms/step\n",
"step 20/938 - loss: 1.5538 - acc_top1: 0.7787 - acc_top2: 0.8115 - 16ms/step\n",
"step 30/938 - loss: 1.5591 - acc_top1: 0.7801 - acc_top2: 0.8128 - 16ms/step\n",
"step 40/938 - loss: 1.5234 - acc_top1: 0.7813 - acc_top2: 0.8138 - 16ms/step\n",
"step 50/938 - loss: 1.6375 - acc_top1: 0.7827 - acc_top2: 0.8150 - 16ms/step\n",
"step 60/938 - loss: 1.5435 - acc_top1: 0.7836 - acc_top2: 0.8159 - 16ms/step\n",
"step 70/938 - loss: 1.5900 - acc_top1: 0.7849 - acc_top2: 0.8170 - 15ms/step\n",
"step 80/938 - loss: 1.5130 - acc_top1: 0.7861 - acc_top2: 0.8180 - 15ms/step\n",
"step 90/938 - loss: 1.6275 - acc_top1: 0.7873 - acc_top2: 0.8190 - 15ms/step\n",
"step 100/938 - loss: 1.5574 - acc_top1: 0.7884 - acc_top2: 0.8200 - 15ms/step\n",
"step 110/938 - loss: 1.5883 - acc_top1: 0.7894 - acc_top2: 0.8208 - 15ms/step\n",
"step 120/938 - loss: 1.5808 - acc_top1: 0.7903 - acc_top2: 0.8216 - 15ms/step\n",
"step 130/938 - loss: 1.5924 - acc_top1: 0.7913 - acc_top2: 0.8226 - 15ms/step\n",
"step 140/938 - loss: 1.5238 - acc_top1: 0.7924 - acc_top2: 0.8234 - 15ms/step\n",
"step 150/938 - loss: 1.6007 - acc_top1: 0.7933 - acc_top2: 0.8242 - 15ms/step\n",
"step 160/938 - loss: 1.6028 - acc_top1: 0.7942 - acc_top2: 0.8250 - 15ms/step\n",
"step 170/938 - loss: 1.5838 - acc_top1: 0.7952 - acc_top2: 0.8258 - 14ms/step\n",
"step 180/938 - loss: 1.6334 - acc_top1: 0.7958 - acc_top2: 0.8265 - 14ms/step\n",
"step 190/938 - loss: 1.6356 - acc_top1: 0.7966 - acc_top2: 0.8272 - 14ms/step\n",
"step 200/938 - loss: 1.7113 - acc_top1: 0.7973 - acc_top2: 0.8279 - 14ms/step\n",
"step 210/938 - loss: 1.5355 - acc_top1: 0.7980 - acc_top2: 0.8285 - 14ms/step\n",
"step 220/938 - loss: 1.5510 - acc_top1: 0.7989 - acc_top2: 0.8293 - 14ms/step\n",
"step 230/938 - loss: 1.5542 - acc_top1: 0.7997 - acc_top2: 0.8300 - 14ms/step\n",
"step 240/938 - loss: 1.5730 - acc_top1: 0.8007 - acc_top2: 0.8309 - 14ms/step\n",
"step 250/938 - loss: 1.5378 - acc_top1: 0.8016 - acc_top2: 0.8317 - 14ms/step\n",
"step 260/938 - loss: 1.5517 - acc_top1: 0.8026 - acc_top2: 0.8326 - 14ms/step\n",
"step 270/938 - loss: 1.5762 - acc_top1: 0.8033 - acc_top2: 0.8332 - 14ms/step\n",
"step 280/938 - loss: 1.5611 - acc_top1: 0.8041 - acc_top2: 0.8339 - 14ms/step\n",
"step 290/938 - loss: 1.6558 - acc_top1: 0.8046 - acc_top2: 0.8343 - 14ms/step\n",
"step 300/938 - loss: 1.5836 - acc_top1: 0.8052 - acc_top2: 0.8349 - 14ms/step\n",
"step 310/938 - loss: 1.5332 - acc_top1: 0.8060 - acc_top2: 0.8356 - 14ms/step\n",
"step 320/938 - loss: 1.5105 - acc_top1: 0.8068 - acc_top2: 0.8362 - 14ms/step\n",
"step 330/938 - loss: 1.5325 - acc_top1: 0.8075 - acc_top2: 0.8368 - 14ms/step\n",
"step 340/938 - loss: 1.5779 - acc_top1: 0.8079 - acc_top2: 0.8371 - 14ms/step\n",
"step 350/938 - loss: 1.5903 - acc_top1: 0.8085 - acc_top2: 0.8377 - 14ms/step\n",
"step 360/938 - loss: 1.5884 - acc_top1: 0.8091 - acc_top2: 0.8382 - 14ms/step\n",
"step 370/938 - loss: 1.6248 - acc_top1: 0.8098 - acc_top2: 0.8388 - 14ms/step\n",
"step 380/938 - loss: 1.6995 - acc_top1: 0.8103 - acc_top2: 0.8392 - 14ms/step\n",
"step 390/938 - loss: 1.5695 - acc_top1: 0.8109 - acc_top2: 0.8397 - 14ms/step\n",
"step 400/938 - loss: 1.6015 - acc_top1: 0.8116 - acc_top2: 0.8403 - 14ms/step\n",
"step 410/938 - loss: 1.5643 - acc_top1: 0.8123 - acc_top2: 0.8409 - 14ms/step\n",
"step 420/938 - loss: 1.5745 - acc_top1: 0.8128 - acc_top2: 0.8413 - 14ms/step\n",
"step 430/938 - loss: 1.5517 - acc_top1: 0.8133 - acc_top2: 0.8417 - 14ms/step\n",
"step 440/938 - loss: 1.6097 - acc_top1: 0.8137 - acc_top2: 0.8419 - 14ms/step\n",
"step 450/938 - loss: 1.5700 - acc_top1: 0.8142 - acc_top2: 0.8424 - 14ms/step\n",
"step 460/938 - loss: 1.5696 - acc_top1: 0.8149 - acc_top2: 0.8430 - 14ms/step\n",
"step 470/938 - loss: 1.5713 - acc_top1: 0.8156 - acc_top2: 0.8436 - 14ms/step\n",
"step 480/938 - loss: 1.5909 - acc_top1: 0.8162 - acc_top2: 0.8440 - 14ms/step\n",
"step 490/938 - loss: 1.6004 - acc_top1: 0.8166 - acc_top2: 0.8443 - 14ms/step\n",
"step 500/938 - loss: 1.4905 - acc_top1: 0.8171 - acc_top2: 0.8448 - 14ms/step\n",
"step 510/938 - loss: 1.5842 - acc_top1: 0.8178 - acc_top2: 0.8453 - 14ms/step\n",
"step 520/938 - loss: 1.5237 - acc_top1: 0.8185 - acc_top2: 0.8459 - 14ms/step\n",
"step 530/938 - loss: 1.5130 - acc_top1: 0.8190 - acc_top2: 0.8464 - 14ms/step\n",
"step 540/938 - loss: 1.5345 - acc_top1: 0.8195 - acc_top2: 0.8468 - 14ms/step\n",
"step 550/938 - loss: 1.5513 - acc_top1: 0.8199 - acc_top2: 0.8472 - 14ms/step\n",
"step 560/938 - loss: 1.5704 - acc_top1: 0.8206 - acc_top2: 0.8480 - 14ms/step\n",
"step 570/938 - loss: 1.5481 - acc_top1: 0.8215 - acc_top2: 0.8490 - 14ms/step\n",
"step 580/938 - loss: 1.5087 - acc_top1: 0.8225 - acc_top2: 0.8500 - 14ms/step\n",
"step 590/938 - loss: 1.4844 - acc_top1: 0.8236 - acc_top2: 0.8510 - 14ms/step\n",
"step 600/938 - loss: 1.5162 - acc_top1: 0.8246 - acc_top2: 0.8520 - 14ms/step\n",
"step 610/938 - loss: 1.4730 - acc_top1: 0.8256 - acc_top2: 0.8530 - 14ms/step\n",
"step 620/938 - loss: 1.5118 - acc_top1: 0.8266 - acc_top2: 0.8540 - 14ms/step\n",
"step 630/938 - loss: 1.4760 - acc_top1: 0.8276 - acc_top2: 0.8550 - 14ms/step\n",
"step 640/938 - loss: 1.4942 - acc_top1: 0.8286 - acc_top2: 0.8559 - 14ms/step\n",
"step 650/938 - loss: 1.5077 - acc_top1: 0.8295 - acc_top2: 0.8568 - 14ms/step\n",
"step 660/938 - loss: 1.4973 - acc_top1: 0.8305 - acc_top2: 0.8577 - 14ms/step\n",
"step 670/938 - loss: 1.5091 - acc_top1: 0.8314 - acc_top2: 0.8586 - 14ms/step\n",
"step 680/938 - loss: 1.4692 - acc_top1: 0.8323 - acc_top2: 0.8595 - 14ms/step\n",
"step 690/938 - loss: 1.4746 - acc_top1: 0.8332 - acc_top2: 0.8604 - 14ms/step\n",
"step 700/938 - loss: 1.4619 - acc_top1: 0.8342 - acc_top2: 0.8613 - 14ms/step\n",
"step 710/938 - loss: 1.5545 - acc_top1: 0.8350 - acc_top2: 0.8621 - 14ms/step\n",
"step 720/938 - loss: 1.4629 - acc_top1: 0.8360 - acc_top2: 0.8630 - 14ms/step\n",
"step 730/938 - loss: 1.4738 - acc_top1: 0.8369 - acc_top2: 0.8638 - 14ms/step\n",
"step 740/938 - loss: 1.4804 - acc_top1: 0.8378 - acc_top2: 0.8647 - 14ms/step\n",
"step 750/938 - loss: 1.4808 - acc_top1: 0.8386 - acc_top2: 0.8655 - 14ms/step\n",
"step 760/938 - loss: 1.4695 - acc_top1: 0.8395 - acc_top2: 0.8664 - 14ms/step\n",
"step 770/938 - loss: 1.4899 - acc_top1: 0.8403 - acc_top2: 0.8671 - 14ms/step\n",
"step 780/938 - loss: 1.5244 - acc_top1: 0.8411 - acc_top2: 0.8679 - 14ms/step\n",
"step 790/938 - loss: 1.5005 - acc_top1: 0.8418 - acc_top2: 0.8686 - 14ms/step\n",
"step 800/938 - loss: 1.4614 - acc_top1: 0.8427 - acc_top2: 0.8694 - 14ms/step\n",
"step 810/938 - loss: 1.5227 - acc_top1: 0.8434 - acc_top2: 0.8702 - 14ms/step\n",
"step 820/938 - loss: 1.4852 - acc_top1: 0.8442 - acc_top2: 0.8709 - 14ms/step\n",
"step 830/938 - loss: 1.4978 - acc_top1: 0.8450 - acc_top2: 0.8717 - 14ms/step\n",
"step 840/938 - loss: 1.4686 - acc_top1: 0.8458 - acc_top2: 0.8724 - 14ms/step\n",
"step 850/938 - loss: 1.4729 - acc_top1: 0.8466 - acc_top2: 0.8732 - 14ms/step\n",
"step 860/938 - loss: 1.4715 - acc_top1: 0.8473 - acc_top2: 0.8739 - 14ms/step\n",
"step 870/938 - loss: 1.5115 - acc_top1: 0.8481 - acc_top2: 0.8746 - 14ms/step\n",
"step 880/938 - loss: 1.4826 - acc_top1: 0.8488 - acc_top2: 0.8753 - 14ms/step\n",
"step 890/938 - loss: 1.4619 - acc_top1: 0.8496 - acc_top2: 0.8760 - 14ms/step\n",
"step 900/938 - loss: 1.4775 - acc_top1: 0.8504 - acc_top2: 0.8767 - 14ms/step\n",
"step 910/938 - loss: 1.4614 - acc_top1: 0.8511 - acc_top2: 0.8773 - 14ms/step\n",
"step 920/938 - loss: 1.4762 - acc_top1: 0.8518 - acc_top2: 0.8780 - 14ms/step\n",
"step 930/938 - loss: 1.5233 - acc_top1: 0.8525 - acc_top2: 0.8787 - 14ms/step\n",
"step 938/938 - loss: 1.4616 - acc_top1: 0.8531 - acc_top2: 0.8792 - 14ms/step\n",
"step 100/938 - loss: 1.5644 - acc_top1: 0.6281 - acc_top2: 0.7145 - 14ms/step\n",
"step 200/938 - loss: 1.6221 - acc_top1: 0.7634 - acc_top2: 0.8380 - 13ms/step\n",
"step 300/938 - loss: 1.5123 - acc_top1: 0.8215 - acc_top2: 0.8835 - 13ms/step\n",
"step 400/938 - loss: 1.4791 - acc_top1: 0.8530 - acc_top2: 0.9084 - 13ms/step\n",
"step 500/938 - loss: 1.4904 - acc_top1: 0.8733 - acc_top2: 0.9235 - 13ms/step\n",
"step 600/938 - loss: 1.5101 - acc_top1: 0.8875 - acc_top2: 0.9341 - 13ms/step\n",
"step 700/938 - loss: 1.4642 - acc_top1: 0.8983 - acc_top2: 0.9417 - 13ms/step\n",
"step 800/938 - loss: 1.4789 - acc_top1: 0.9069 - acc_top2: 0.9477 - 13ms/step\n",
"step 900/938 - loss: 1.4773 - acc_top1: 0.9135 - acc_top2: 0.9523 - 13ms/step\n",
"step 938/938 - loss: 1.4714 - acc_top1: 0.9157 - acc_top2: 0.9538 - 13ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/0\n",
"Epoch 2/2\n",
"step 10/938 - loss: 1.5034 - acc_top1: 0.9688 - acc_top2: 0.9891 - 16ms/step\n",
"step 20/938 - loss: 1.4879 - acc_top1: 0.9711 - acc_top2: 0.9898 - 16ms/step\n",
"step 30/938 - loss: 1.4661 - acc_top1: 0.9734 - acc_top2: 0.9906 - 15ms/step\n",
"step 40/938 - loss: 1.5272 - acc_top1: 0.9746 - acc_top2: 0.9914 - 15ms/step\n",
"step 50/938 - loss: 1.4768 - acc_top1: 0.9747 - acc_top2: 0.9919 - 15ms/step\n",
"step 60/938 - loss: 1.4924 - acc_top1: 0.9719 - acc_top2: 0.9898 - 15ms/step\n",
"step 70/938 - loss: 1.4800 - acc_top1: 0.9725 - acc_top2: 0.9906 - 15ms/step\n",
"step 80/938 - loss: 1.5056 - acc_top1: 0.9734 - acc_top2: 0.9914 - 15ms/step\n",
"step 90/938 - loss: 1.4628 - acc_top1: 0.9740 - acc_top2: 0.9915 - 15ms/step\n",
"step 100/938 - loss: 1.4794 - acc_top1: 0.9745 - acc_top2: 0.9919 - 15ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 110/938 - loss: 1.4665 - acc_top1: 0.9749 - acc_top2: 0.9918 - 15ms/step\n",
"step 120/938 - loss: 1.4769 - acc_top1: 0.9755 - acc_top2: 0.9917 - 15ms/step\n",
"step 130/938 - loss: 1.4883 - acc_top1: 0.9755 - acc_top2: 0.9918 - 15ms/step\n",
"step 140/938 - loss: 1.4779 - acc_top1: 0.9757 - acc_top2: 0.9920 - 15ms/step\n",
"step 150/938 - loss: 1.4926 - acc_top1: 0.9751 - acc_top2: 0.9922 - 15ms/step\n",
"step 160/938 - loss: 1.5458 - acc_top1: 0.9750 - acc_top2: 0.9924 - 15ms/step\n",
"step 170/938 - loss: 1.5166 - acc_top1: 0.9748 - acc_top2: 0.9924 - 15ms/step\n",
"step 180/938 - loss: 1.4676 - acc_top1: 0.9748 - acc_top2: 0.9923 - 14ms/step\n",
"step 190/938 - loss: 1.4773 - acc_top1: 0.9748 - acc_top2: 0.9924 - 15ms/step\n",
"step 200/938 - loss: 1.4893 - acc_top1: 0.9752 - acc_top2: 0.9928 - 15ms/step\n",
"step 210/938 - loss: 1.5408 - acc_top1: 0.9751 - acc_top2: 0.9926 - 15ms/step\n",
"step 220/938 - loss: 1.4934 - acc_top1: 0.9753 - acc_top2: 0.9925 - 15ms/step\n",
"step 230/938 - loss: 1.5162 - acc_top1: 0.9753 - acc_top2: 0.9925 - 15ms/step\n",
"step 240/938 - loss: 1.5097 - acc_top1: 0.9752 - acc_top2: 0.9926 - 14ms/step\n",
"step 250/938 - loss: 1.5264 - acc_top1: 0.9752 - acc_top2: 0.9927 - 14ms/step\n",
"step 260/938 - loss: 1.4843 - acc_top1: 0.9752 - acc_top2: 0.9926 - 14ms/step\n",
"step 270/938 - loss: 1.4818 - acc_top1: 0.9753 - acc_top2: 0.9927 - 14ms/step\n",
"step 280/938 - loss: 1.4627 - acc_top1: 0.9755 - acc_top2: 0.9925 - 14ms/step\n",
"step 290/938 - loss: 1.4932 - acc_top1: 0.9755 - acc_top2: 0.9923 - 14ms/step\n",
"step 300/938 - loss: 1.4641 - acc_top1: 0.9754 - acc_top2: 0.9924 - 14ms/step\n",
"step 310/938 - loss: 1.4908 - acc_top1: 0.9757 - acc_top2: 0.9925 - 14ms/step\n",
"step 320/938 - loss: 1.4910 - acc_top1: 0.9756 - acc_top2: 0.9926 - 14ms/step\n",
"step 330/938 - loss: 1.4693 - acc_top1: 0.9754 - acc_top2: 0.9925 - 14ms/step\n",
"step 340/938 - loss: 1.4968 - acc_top1: 0.9755 - acc_top2: 0.9925 - 14ms/step\n",
"step 350/938 - loss: 1.4963 - acc_top1: 0.9754 - acc_top2: 0.9923 - 14ms/step\n",
"step 360/938 - loss: 1.5005 - acc_top1: 0.9755 - acc_top2: 0.9925 - 14ms/step\n",
"step 370/938 - loss: 1.4683 - acc_top1: 0.9756 - acc_top2: 0.9925 - 14ms/step\n",
"step 380/938 - loss: 1.5069 - acc_top1: 0.9757 - acc_top2: 0.9925 - 14ms/step\n",
"step 390/938 - loss: 1.4619 - acc_top1: 0.9759 - acc_top2: 0.9926 - 14ms/step\n",
"step 400/938 - loss: 1.4627 - acc_top1: 0.9761 - acc_top2: 0.9928 - 14ms/step\n",
"step 410/938 - loss: 1.5207 - acc_top1: 0.9760 - acc_top2: 0.9928 - 14ms/step\n",
"step 420/938 - loss: 1.5234 - acc_top1: 0.9758 - acc_top2: 0.9928 - 14ms/step\n",
"step 430/938 - loss: 1.4797 - acc_top1: 0.9759 - acc_top2: 0.9930 - 14ms/step\n",
"step 440/938 - loss: 1.4618 - acc_top1: 0.9759 - acc_top2: 0.9929 - 14ms/step\n",
"step 450/938 - loss: 1.4760 - acc_top1: 0.9759 - acc_top2: 0.9930 - 14ms/step\n",
"step 460/938 - loss: 1.4612 - acc_top1: 0.9760 - acc_top2: 0.9929 - 14ms/step\n",
"step 470/938 - loss: 1.4922 - acc_top1: 0.9759 - acc_top2: 0.9929 - 14ms/step\n",
"step 480/938 - loss: 1.4831 - acc_top1: 0.9761 - acc_top2: 0.9930 - 14ms/step\n",
"step 490/938 - loss: 1.4914 - acc_top1: 0.9762 - acc_top2: 0.9931 - 14ms/step\n",
"step 500/938 - loss: 1.4861 - acc_top1: 0.9760 - acc_top2: 0.9931 - 14ms/step\n",
"step 510/938 - loss: 1.4703 - acc_top1: 0.9760 - acc_top2: 0.9931 - 14ms/step\n",
"step 520/938 - loss: 1.5457 - acc_top1: 0.9755 - acc_top2: 0.9930 - 14ms/step\n",
"step 530/938 - loss: 1.4808 - acc_top1: 0.9754 - acc_top2: 0.9929 - 14ms/step\n",
"step 540/938 - loss: 1.4831 - acc_top1: 0.9753 - acc_top2: 0.9929 - 14ms/step\n",
"step 550/938 - loss: 1.5072 - acc_top1: 0.9753 - acc_top2: 0.9929 - 14ms/step\n",
"step 560/938 - loss: 1.4628 - acc_top1: 0.9753 - acc_top2: 0.9929 - 14ms/step\n",
"step 570/938 - loss: 1.4721 - acc_top1: 0.9753 - acc_top2: 0.9929 - 14ms/step\n",
"step 580/938 - loss: 1.4768 - acc_top1: 0.9755 - acc_top2: 0.9929 - 14ms/step\n",
"step 590/938 - loss: 1.4793 - acc_top1: 0.9755 - acc_top2: 0.9930 - 14ms/step\n",
"step 600/938 - loss: 1.4916 - acc_top1: 0.9754 - acc_top2: 0.9929 - 14ms/step\n",
"step 610/938 - loss: 1.4612 - acc_top1: 0.9755 - acc_top2: 0.9930 - 14ms/step\n",
"step 620/938 - loss: 1.5034 - acc_top1: 0.9753 - acc_top2: 0.9930 - 14ms/step\n",
"step 630/938 - loss: 1.4640 - acc_top1: 0.9754 - acc_top2: 0.9930 - 14ms/step\n",
"step 640/938 - loss: 1.4822 - acc_top1: 0.9756 - acc_top2: 0.9931 - 14ms/step\n",
"step 650/938 - loss: 1.4955 - acc_top1: 0.9756 - acc_top2: 0.9932 - 14ms/step\n",
"step 660/938 - loss: 1.4615 - acc_top1: 0.9757 - acc_top2: 0.9933 - 14ms/step\n",
"step 670/938 - loss: 1.5060 - acc_top1: 0.9757 - acc_top2: 0.9932 - 14ms/step\n",
"step 680/938 - loss: 1.4640 - acc_top1: 0.9758 - acc_top2: 0.9933 - 14ms/step\n",
"step 690/938 - loss: 1.5070 - acc_top1: 0.9759 - acc_top2: 0.9933 - 14ms/step\n",
"step 700/938 - loss: 1.4615 - acc_top1: 0.9761 - acc_top2: 0.9933 - 14ms/step\n",
"step 710/938 - loss: 1.5210 - acc_top1: 0.9760 - acc_top2: 0.9933 - 14ms/step\n",
"step 720/938 - loss: 1.5154 - acc_top1: 0.9761 - acc_top2: 0.9933 - 14ms/step\n",
"step 730/938 - loss: 1.4965 - acc_top1: 0.9760 - acc_top2: 0.9933 - 14ms/step\n",
"step 740/938 - loss: 1.4612 - acc_top1: 0.9761 - acc_top2: 0.9933 - 14ms/step\n",
"step 750/938 - loss: 1.4878 - acc_top1: 0.9761 - acc_top2: 0.9934 - 14ms/step\n",
"step 760/938 - loss: 1.4775 - acc_top1: 0.9761 - acc_top2: 0.9933 - 14ms/step\n",
"step 770/938 - loss: 1.4834 - acc_top1: 0.9762 - acc_top2: 0.9933 - 14ms/step\n",
"step 780/938 - loss: 1.4661 - acc_top1: 0.9763 - acc_top2: 0.9934 - 14ms/step\n",
"step 790/938 - loss: 1.4895 - acc_top1: 0.9764 - acc_top2: 0.9934 - 14ms/step\n",
"step 800/938 - loss: 1.4767 - acc_top1: 0.9765 - acc_top2: 0.9934 - 14ms/step\n",
"step 810/938 - loss: 1.4779 - acc_top1: 0.9767 - acc_top2: 0.9935 - 14ms/step\n",
"step 820/938 - loss: 1.4768 - acc_top1: 0.9766 - acc_top2: 0.9935 - 14ms/step\n",
"step 830/938 - loss: 1.4630 - acc_top1: 0.9767 - acc_top2: 0.9935 - 14ms/step\n",
"step 840/938 - loss: 1.4612 - acc_top1: 0.9767 - acc_top2: 0.9936 - 14ms/step\n",
"step 850/938 - loss: 1.4993 - acc_top1: 0.9766 - acc_top2: 0.9935 - 14ms/step\n",
"step 860/938 - loss: 1.4817 - acc_top1: 0.9766 - acc_top2: 0.9936 - 14ms/step\n",
"step 870/938 - loss: 1.4786 - acc_top1: 0.9766 - acc_top2: 0.9935 - 14ms/step\n",
"step 880/938 - loss: 1.4772 - acc_top1: 0.9765 - acc_top2: 0.9935 - 14ms/step\n",
"step 890/938 - loss: 1.4646 - acc_top1: 0.9766 - acc_top2: 0.9936 - 14ms/step\n",
"step 900/938 - loss: 1.4813 - acc_top1: 0.9766 - acc_top2: 0.9936 - 14ms/step\n",
"step 910/938 - loss: 1.4815 - acc_top1: 0.9767 - acc_top2: 0.9936 - 14ms/step\n",
"step 920/938 - loss: 1.4854 - acc_top1: 0.9767 - acc_top2: 0.9936 - 14ms/step\n",
"step 930/938 - loss: 1.4890 - acc_top1: 0.9769 - acc_top2: 0.9936 - 14ms/step\n",
"step 938/938 - loss: 1.4617 - acc_top1: 0.9769 - acc_top2: 0.9937 - 14ms/step\n",
"step 100/938 - loss: 1.4863 - acc_top1: 0.9695 - acc_top2: 0.9897 - 13ms/step\n",
"step 200/938 - loss: 1.4883 - acc_top1: 0.9707 - acc_top2: 0.9912 - 13ms/step\n",
"step 300/938 - loss: 1.4695 - acc_top1: 0.9720 - acc_top2: 0.9910 - 13ms/step\n",
"step 400/938 - loss: 1.4628 - acc_top1: 0.9720 - acc_top2: 0.9915 - 13ms/step\n",
"step 500/938 - loss: 1.5079 - acc_top1: 0.9727 - acc_top2: 0.9918 - 13ms/step\n",
"step 600/938 - loss: 1.4803 - acc_top1: 0.9727 - acc_top2: 0.9919 - 13ms/step\n",
"step 700/938 - loss: 1.4612 - acc_top1: 0.9732 - acc_top2: 0.9923 - 13ms/step\n",
"step 800/938 - loss: 1.4755 - acc_top1: 0.9732 - acc_top2: 0.9923 - 13ms/step\n",
"step 900/938 - loss: 1.4698 - acc_top1: 0.9732 - acc_top2: 0.9922 - 13ms/step\n",
"step 938/938 - loss: 1.4764 - acc_top1: 0.9734 - acc_top2: 0.9923 - 13ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/final\n"
]
@@ -530,6 +354,7 @@
"model.fit(train_dataset,\n",
" epochs=2,\n",
" batch_size=64,\n",
" log_freq=100,\n",
" save_dir='mnist_checkpoint')"
]
},
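The new `log_freq=100` argument is what shrank the epoch logs above from one line every 10 steps to one every 100; `save_dir='mnist_checkpoint'` still writes the per-epoch and final checkpoints shown in the output.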
@@ -542,7 +367,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -550,32 +375,32 @@
"output_type": "stream",
"text": [
"Eval begin...\n",
"step 10/157 - loss: 1.5023 - acc_top1: 0.9781 - acc_top2: 0.9969 - 7ms/step\n",
"step 20/157 - loss: 1.5326 - acc_top1: 0.9750 - acc_top2: 0.9953 - 7ms/step\n",
"step 30/157 - loss: 1.4881 - acc_top1: 0.9745 - acc_top2: 0.9943 - 7ms/step\n",
"step 40/157 - loss: 1.4703 - acc_top1: 0.9715 - acc_top2: 0.9934 - 6ms/step\n",
"step 50/157 - loss: 1.4793 - acc_top1: 0.9728 - acc_top2: 0.9934 - 6ms/step\n",
"step 60/157 - loss: 1.5338 - acc_top1: 0.9721 - acc_top2: 0.9924 - 6ms/step\n",
"step 70/157 - loss: 1.4801 - acc_top1: 0.9721 - acc_top2: 0.9922 - 6ms/step\n",
"step 80/157 - loss: 1.4763 - acc_top1: 0.9725 - acc_top2: 0.9928 - 6ms/step\n",
"step 90/157 - loss: 1.4682 - acc_top1: 0.9747 - acc_top2: 0.9936 - 6ms/step\n",
"step 100/157 - loss: 1.4780 - acc_top1: 0.9758 - acc_top2: 0.9939 - 6ms/step\n",
"step 110/157 - loss: 1.4686 - acc_top1: 0.9763 - acc_top2: 0.9942 - 6ms/step\n",
"step 120/157 - loss: 1.4624 - acc_top1: 0.9780 - acc_top2: 0.9947 - 6ms/step\n",
"step 130/157 - loss: 1.4968 - acc_top1: 0.9787 - acc_top2: 0.9948 - 6ms/step\n",
"step 140/157 - loss: 1.4612 - acc_top1: 0.9798 - acc_top2: 0.9952 - 6ms/step\n",
"step 150/157 - loss: 1.4613 - acc_top1: 0.9806 - acc_top2: 0.9955 - 6ms/step\n",
"step 157/157 - loss: 1.4612 - acc_top1: 0.9803 - acc_top2: 0.9955 - 6ms/step\n",
"step 10/157 - loss: 1.5238 - acc_top1: 0.9750 - acc_top2: 0.9938 - 7ms/step\n",
"step 20/157 - loss: 1.5143 - acc_top1: 0.9727 - acc_top2: 0.9922 - 7ms/step\n",
"step 30/157 - loss: 1.5290 - acc_top1: 0.9698 - acc_top2: 0.9932 - 7ms/step\n",
"step 40/157 - loss: 1.4624 - acc_top1: 0.9684 - acc_top2: 0.9930 - 7ms/step\n",
"step 50/157 - loss: 1.4771 - acc_top1: 0.9697 - acc_top2: 0.9925 - 7ms/step\n",
"step 60/157 - loss: 1.5066 - acc_top1: 0.9701 - acc_top2: 0.9922 - 6ms/step\n",
"step 70/157 - loss: 1.4804 - acc_top1: 0.9699 - acc_top2: 0.9920 - 6ms/step\n",
"step 80/157 - loss: 1.4718 - acc_top1: 0.9707 - acc_top2: 0.9930 - 6ms/step\n",
"step 90/157 - loss: 1.4874 - acc_top1: 0.9726 - acc_top2: 0.9934 - 6ms/step\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9736 - acc_top2: 0.9936 - 6ms/step\n",
"step 110/157 - loss: 1.4612 - acc_top1: 0.9746 - acc_top2: 0.9938 - 6ms/step\n",
"step 120/157 - loss: 1.4763 - acc_top1: 0.9763 - acc_top2: 0.9941 - 6ms/step\n",
"step 130/157 - loss: 1.4786 - acc_top1: 0.9764 - acc_top2: 0.9935 - 6ms/step\n",
"step 140/157 - loss: 1.4612 - acc_top1: 0.9775 - acc_top2: 0.9939 - 6ms/step\n",
"step 150/157 - loss: 1.4894 - acc_top1: 0.9785 - acc_top2: 0.9943 - 6ms/step\n",
"step 157/157 - loss: 1.4612 - acc_top1: 0.9777 - acc_top2: 0.9941 - 6ms/step\n",
"Eval samples: 10000\n"
]
},
{
"data": {
"text/plain": [
"{'loss': [1.4611506], 'acc_top1': 0.9803, 'acc_top2': 0.9955}"
"{'loss': [1.4611504], 'acc_top1': 0.9777, 'acc_top2': 0.9941}"
]
},
"execution_count": 12,
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
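A sketch of the call that presumably produced this evaluation output, reusing the prepared `paddle.Model`; the `batch_size` is an assumption:

```python
result = model.evaluate(test_dataset, batch_size=64)
print(result)  # e.g. {'loss': [...], 'acc_top1': ..., 'acc_top2': ...}
```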
......