提交 54a9f315 编写于 作者: T TC.Long

fix model loss

上级 8c1aa4e6
......@@ -19,7 +19,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"outputs": [
{
......@@ -47,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [
{
......@@ -75,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [
{
......@@ -118,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
......@@ -162,33 +162,23 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch_id: 0, loss is: [2.3037894], acc is: [0.140625]\n",
"epoch: 0, batch_id: 100, loss is: [1.6175328], acc is: [0.9375]\n",
"epoch: 0, batch_id: 200, loss is: [1.5388051], acc is: [0.96875]\n",
"epoch: 0, batch_id: 300, loss is: [1.5251061], acc is: [0.96875]\n",
"epoch: 0, batch_id: 400, loss is: [1.4678856], acc is: [1.]\n",
"epoch: 0, batch_id: 500, loss is: [1.4944503], acc is: [0.984375]\n",
"epoch: 0, batch_id: 600, loss is: [1.5365536], acc is: [0.96875]\n",
"epoch: 0, batch_id: 700, loss is: [1.4885054], acc is: [0.984375]\n",
"epoch: 0, batch_id: 800, loss is: [1.4872254], acc is: [0.984375]\n",
"epoch: 0, batch_id: 900, loss is: [1.4884174], acc is: [0.984375]\n",
"epoch: 1, batch_id: 0, loss is: [1.4776722], acc is: [1.]\n",
"epoch: 1, batch_id: 100, loss is: [1.4751343], acc is: [1.]\n",
"epoch: 1, batch_id: 200, loss is: [1.4772581], acc is: [1.]\n",
"epoch: 1, batch_id: 300, loss is: [1.4918218], acc is: [0.984375]\n",
"epoch: 1, batch_id: 400, loss is: [1.5038397], acc is: [0.96875]\n",
"epoch: 1, batch_id: 500, loss is: [1.5088196], acc is: [0.96875]\n",
"epoch: 1, batch_id: 600, loss is: [1.4961376], acc is: [0.984375]\n",
"epoch: 1, batch_id: 700, loss is: [1.4755756], acc is: [1.]\n",
"epoch: 1, batch_id: 800, loss is: [1.4921497], acc is: [0.984375]\n",
"epoch: 1, batch_id: 900, loss is: [1.4944404], acc is: [1.]\n"
"epoch: 0, batch_id: 0, loss is: [2.3017962], acc is: [0.28125]\n",
"epoch: 0, batch_id: 200, loss is: [1.5294291], acc is: [0.96875]\n",
"epoch: 0, batch_id: 400, loss is: [1.4693298], acc is: [1.]\n",
"epoch: 0, batch_id: 600, loss is: [1.5237448], acc is: [0.984375]\n",
"epoch: 0, batch_id: 800, loss is: [1.4795951], acc is: [0.984375]\n",
"epoch: 1, batch_id: 0, loss is: [1.5161536], acc is: [0.96875]\n",
"epoch: 1, batch_id: 200, loss is: [1.4763479], acc is: [1.]\n",
"epoch: 1, batch_id: 400, loss is: [1.4929678], acc is: [1.]\n",
"epoch: 1, batch_id: 600, loss is: [1.4999642], acc is: [1.]\n",
"epoch: 1, batch_id: 800, loss is: [1.5029153], acc is: [0.984375]\n"
]
}
],
......@@ -209,11 +199,9 @@
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" # 计算损失\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
" loss.backward()\n",
" if batch_id % 200 == 0:\n",
" print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, loss.numpy(), acc.numpy()))\n",
" optim.step()\n",
" optim.clear_grad()\n",
"model = LeNet()\n",
......@@ -230,21 +218,21 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_id: 0, loss is: [1.4915928], acc is: [1.]\n",
"batch_id: 20, loss is: [1.4818308], acc is: [1.]\n",
"batch_id: 40, loss is: [1.5006062], acc is: [0.984375]\n",
"batch_id: 60, loss is: [1.521233], acc is: [1.]\n",
"batch_id: 80, loss is: [1.4772738], acc is: [1.]\n",
"batch_id: 100, loss is: [1.4755945], acc is: [1.]\n",
"batch_id: 120, loss is: [1.4746133], acc is: [1.]\n",
"batch_id: 140, loss is: [1.4786345], acc is: [1.]\n"
"batch_id: 0, loss is: [1.4616354], acc is: [1.]\n",
"batch_id: 20, loss is: [1.4927294], acc is: [0.984375]\n",
"batch_id: 40, loss is: [1.4990321], acc is: [1.]\n",
"batch_id: 60, loss is: [1.4892884], acc is: [1.]\n",
"batch_id: 80, loss is: [1.4767071], acc is: [1.]\n",
"batch_id: 100, loss is: [1.4611524], acc is: [1.]\n",
"batch_id: 120, loss is: [1.4613531], acc is: [1.]\n",
"batch_id: 140, loss is: [1.4928315], acc is: [1.]\n"
]
}
],
......@@ -262,11 +250,9 @@
" # 获取预测结果\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" loss.backward()\n",
" if batch_id % 20 == 0:\n",
" print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
" print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, loss.numpy(), acc.numpy()))\n",
"test(model)"
]
},
......@@ -288,7 +274,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
......@@ -316,7 +302,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"metadata": {},
"outputs": [
{
......@@ -324,17 +310,17 @@
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"step 200/938 - loss: 1.5219 - acc_top1: 0.9829 - acc_top2: 0.9965 - 14ms/step\n",
"step 400/938 - loss: 1.4765 - acc_top1: 0.9825 - acc_top2: 0.9958 - 13ms/step\n",
"step 600/938 - loss: 1.4624 - acc_top1: 0.9823 - acc_top2: 0.9953 - 13ms/step\n",
"step 800/938 - loss: 1.4768 - acc_top1: 0.9829 - acc_top2: 0.9955 - 13ms/step\n",
"step 938/938 - loss: 1.4612 - acc_top1: 0.9836 - acc_top2: 0.9956 - 13ms/step\n",
"step 200/938 - loss: 1.4868 - acc_top1: 0.9805 - acc_top2: 0.9951 - 14ms/step\n",
"step 400/938 - loss: 1.4643 - acc_top1: 0.9802 - acc_top2: 0.9944 - 14ms/step\n",
"step 600/938 - loss: 1.4638 - acc_top1: 0.9799 - acc_top2: 0.9942 - 13ms/step\n",
"step 800/938 - loss: 1.4767 - acc_top1: 0.9801 - acc_top2: 0.9944 - 13ms/step\n",
"step 938/938 - loss: 1.4614 - acc_top1: 0.9804 - acc_top2: 0.9945 - 13ms/step\n",
"Epoch 2/2\n",
"step 200/938 - loss: 1.4705 - acc_top1: 0.9834 - acc_top2: 0.9959 - 13ms/step\n",
"step 400/938 - loss: 1.4620 - acc_top1: 0.9833 - acc_top2: 0.9960 - 13ms/step\n",
"step 600/938 - loss: 1.4613 - acc_top1: 0.9830 - acc_top2: 0.9960 - 13ms/step\n",
"step 800/938 - loss: 1.4763 - acc_top1: 0.9831 - acc_top2: 0.9960 - 13ms/step\n",
"step 938/938 - loss: 1.4924 - acc_top1: 0.9834 - acc_top2: 0.9959 - 13ms/step\n"
"step 200/938 - loss: 1.4618 - acc_top1: 0.9812 - acc_top2: 0.9956 - 13ms/step\n",
"step 400/938 - loss: 1.4778 - acc_top1: 0.9804 - acc_top2: 0.9952 - 13ms/step\n",
"step 600/938 - loss: 1.4698 - acc_top1: 0.9810 - acc_top2: 0.9954 - 13ms/step\n",
"step 800/938 - loss: 1.4621 - acc_top1: 0.9815 - acc_top2: 0.9957 - 13ms/step\n",
"step 938/938 - loss: 1.4847 - acc_top1: 0.9814 - acc_top2: 0.9958 - 13ms/step\n"
]
}
],
......@@ -355,7 +341,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"metadata": {},
"outputs": [
{
......@@ -363,24 +349,24 @@
"output_type": "stream",
"text": [
"Eval begin...\n",
"step 20/157 - loss: 1.5246 - acc_top1: 0.9773 - acc_top2: 0.9969 - 6ms/step\n",
"step 40/157 - loss: 1.4622 - acc_top1: 0.9758 - acc_top2: 0.9961 - 6ms/step\n",
"step 60/157 - loss: 1.5241 - acc_top1: 0.9763 - acc_top2: 0.9951 - 6ms/step\n",
"step 80/157 - loss: 1.4612 - acc_top1: 0.9787 - acc_top2: 0.9959 - 6ms/step\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9823 - acc_top2: 0.9967 - 5ms/step\n",
"step 120/157 - loss: 1.4612 - acc_top1: 0.9835 - acc_top2: 0.9966 - 5ms/step\n",
"step 140/157 - loss: 1.4612 - acc_top1: 0.9844 - acc_top2: 0.9969 - 5ms/step\n",
"step 157/157 - loss: 1.4612 - acc_top1: 0.9838 - acc_top2: 0.9966 - 5ms/step\n",
"step 20/157 - loss: 1.5160 - acc_top1: 0.9805 - acc_top2: 0.9930 - 7ms/step\n",
"step 40/157 - loss: 1.4612 - acc_top1: 0.9793 - acc_top2: 0.9949 - 6ms/step\n",
"step 60/157 - loss: 1.5095 - acc_top1: 0.9792 - acc_top2: 0.9943 - 6ms/step\n",
"step 80/157 - loss: 1.4612 - acc_top1: 0.9785 - acc_top2: 0.9941 - 6ms/step\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9816 - acc_top2: 0.9950 - 6ms/step\n",
"step 120/157 - loss: 1.4763 - acc_top1: 0.9832 - acc_top2: 0.9954 - 6ms/step\n",
"step 140/157 - loss: 1.4612 - acc_top1: 0.9849 - acc_top2: 0.9959 - 6ms/step\n",
"step 157/157 - loss: 1.4612 - acc_top1: 0.9844 - acc_top2: 0.9959 - 6ms/step\n",
"Eval samples: 10000\n"
]
},
{
"data": {
"text/plain": [
"{'loss': [1.4611504], 'acc_top1': 0.9838, 'acc_top2': 0.9966}"
"{'loss': [1.4611504], 'acc_top1': 0.9844, 'acc_top2': 0.9959}"
]
},
"execution_count": 17,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册