未验证 提交 c888f539 编写于 作者: C Chen Long 提交者: GitHub

fix_lenet_docs test=develop (#886)

上级 3986db50
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 35,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -46,7 +46,7 @@ ...@@ -46,7 +46,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 36,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -74,7 +74,7 @@ ...@@ -74,7 +74,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 37,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -117,7 +117,7 @@ ...@@ -117,7 +117,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 38,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -127,9 +127,9 @@ ...@@ -127,9 +127,9 @@
" def __init__(self):\n", " def __init__(self):\n",
" super(LeNet, self).__init__()\n", " super(LeNet, self).__init__()\n",
" self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n", " self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n",
" self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", " self.max_pool1 = paddle.nn.MaxPool2d(kernel_size=2, stride=2)\n",
" self.conv2 = paddle.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n", " self.conv2 = paddle.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n",
" self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", " self.max_pool2 = paddle.nn.MaxPool2d(kernel_size=2, stride=2)\n",
" self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)\n", " self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)\n",
" self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)\n", " self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)\n",
" self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)\n", " self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)\n",
...@@ -155,13 +155,141 @@ ...@@ -155,13 +155,141 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 训练方式一\n", "# 3.训练方式一\n",
"通过`Model` 构建实例,快速完成模型训练" "组网后,开始对模型进行训练,先构建`train_loader`,加载训练数据,然后定义`train`函数,设置好损失函数后,按batch加载数据,完成模型的训练。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch_id: 0, loss is: [2.3064885], acc is: [0.109375]\n",
"epoch: 0, batch_id: 100, loss is: [1.5477252], acc is: [1.]\n",
"epoch: 0, batch_id: 200, loss is: [1.5201148], acc is: [1.]\n",
"epoch: 0, batch_id: 300, loss is: [1.525354], acc is: [0.953125]\n",
"epoch: 0, batch_id: 400, loss is: [1.5201038], acc is: [1.]\n",
"epoch: 0, batch_id: 500, loss is: [1.4901408], acc is: [1.]\n",
"epoch: 0, batch_id: 600, loss is: [1.4925538], acc is: [0.984375]\n",
"epoch: 0, batch_id: 700, loss is: [1.5247533], acc is: [0.96875]\n",
"epoch: 0, batch_id: 800, loss is: [1.5365943], acc is: [1.]\n",
"epoch: 0, batch_id: 900, loss is: [1.5154861], acc is: [0.984375]\n",
"epoch: 1, batch_id: 0, loss is: [1.4988302], acc is: [0.984375]\n",
"epoch: 1, batch_id: 100, loss is: [1.493154], acc is: [0.984375]\n",
"epoch: 1, batch_id: 200, loss is: [1.4974915], acc is: [1.]\n",
"epoch: 1, batch_id: 300, loss is: [1.5089471], acc is: [0.984375]\n",
"epoch: 1, batch_id: 400, loss is: [1.5041347], acc is: [1.]\n",
"epoch: 1, batch_id: 500, loss is: [1.5145375], acc is: [1.]\n",
"epoch: 1, batch_id: 600, loss is: [1.4904011], acc is: [0.984375]\n",
"epoch: 1, batch_id: 700, loss is: [1.5121607], acc is: [0.96875]\n",
"epoch: 1, batch_id: 800, loss is: [1.5078678], acc is: [1.]\n",
"epoch: 1, batch_id: 900, loss is: [1.500349], acc is: [0.984375]\n"
]
}
],
"source": [
"import paddle\n",
"train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"# 加载训练集 batch_size 设为 64\n",
"def train(model):\n",
" model.train()\n",
" epochs = 2\n",
" optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
" # 用Adam作为优化函数\n",
" for epoch in range(epochs):\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" # 计算损失\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
" optim.minimize(avg_loss)\n",
" model.clear_gradients()\n",
"model = LeNet()\n",
"train(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 对模型进行验证\n",
"训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_id: 0, loss is: [1.4659549], acc is: [1.]\n",
"batch_id: 100, loss is: [1.4933192], acc is: [0.984375]\n",
"batch_id: 200, loss is: [1.4779761], acc is: [1.]\n",
"batch_id: 300, loss is: [1.4919193], acc is: [0.984375]\n",
"batch_id: 400, loss is: [1.5036212], acc is: [1.]\n",
"batch_id: 500, loss is: [1.4922347], acc is: [0.984375]\n",
"batch_id: 600, loss is: [1.4765416], acc is: [0.984375]\n",
"batch_id: 700, loss is: [1.4997746], acc is: [0.984375]\n",
"batch_id: 800, loss is: [1.4831288], acc is: [1.]\n",
"batch_id: 900, loss is: [1.498342], acc is: [0.984375]\n"
]
}
],
"source": [
"import paddle\n",
"test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"# 加载测试数据集\n",
"def test(model):\n",
" model.eval()\n",
" batch_size = 64\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" # 获取预测结果\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
"test(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 训练方式一结束\n",
"以上就是训练方式一,通过这种方式,可以清楚的看到训练和测试中的每一步过程。但是,这种方式句法比较复杂。因此,我们提供了训练方式二,能够更加快速、高效的完成模型的训练与测试。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.训练方式二\n",
"通过paddle提供的`Model` 构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -170,10 +298,9 @@ ...@@ -170,10 +298,9 @@
"from paddle.metric import Accuracy\n", "from paddle.metric import Accuracy\n",
"inputs = InputSpec([None, 784], 'float32', 'x')\n", "inputs = InputSpec([None, 784], 'float32', 'x')\n",
"labels = InputSpec([None, 10], 'float32', 'x')\n", "labels = InputSpec([None, 10], 'float32', 'x')\n",
"model = paddle.hapi.Model(LeNet(), inputs, labels)\n", "model = paddle.Model(LeNet(), inputs, labels)\n",
"optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", "optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
"\n", "\n",
"\n",
"model.prepare(\n", "model.prepare(\n",
" optim,\n", " optim,\n",
" paddle.nn.loss.CrossEntropyLoss(),\n", " paddle.nn.loss.CrossEntropyLoss(),\n",
...@@ -190,7 +317,7 @@ ...@@ -190,7 +317,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 42,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -198,266 +325,209 @@ ...@@ -198,266 +325,209 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 1/2\n", "Epoch 1/2\n",
"step 10/938 - loss: 2.2369 - acc_top1: 0.3281 - acc_top2: 0.4172 - 18ms/step\n" "step 10/938 - loss: 2.2434 - acc_top1: 0.1344 - acc_top2: 0.3719 - 14ms/step\n",
] "step 20/938 - loss: 2.0292 - acc_top1: 0.2836 - acc_top2: 0.4633 - 14ms/step\n",
}, "step 30/938 - loss: 1.9341 - acc_top1: 0.3755 - acc_top2: 0.5214 - 14ms/step\n",
{ "step 40/938 - loss: 1.8009 - acc_top1: 0.4469 - acc_top2: 0.5727 - 14ms/step\n",
"name": "stderr", "step 50/938 - loss: 1.8000 - acc_top1: 0.4975 - acc_top2: 0.6125 - 13ms/step\n",
"output_type": "stream", "step 60/938 - loss: 1.6335 - acc_top1: 0.5417 - acc_top2: 0.6438 - 14ms/step\n",
"text": [ "step 70/938 - loss: 1.7931 - acc_top1: 0.5708 - acc_top2: 0.6643 - 13ms/step\n",
"/Library/Python/3.7/site-packages/paddle/fluid/layers/utils.py:76: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", "step 80/938 - loss: 1.6699 - acc_top1: 0.5961 - acc_top2: 0.6846 - 13ms/step\n",
" return (isinstance(seq, collections.Sequence) and\n" "step 90/938 - loss: 1.6832 - acc_top1: 0.6189 - acc_top2: 0.7069 - 13ms/step\n",
] "step 100/938 - loss: 1.6336 - acc_top1: 0.6409 - acc_top2: 0.7245 - 14ms/step\n",
}, "step 110/938 - loss: 1.6598 - acc_top1: 0.6557 - acc_top2: 0.7376 - 13ms/step\n",
{ "step 120/938 - loss: 1.6348 - acc_top1: 0.6708 - acc_top2: 0.7488 - 13ms/step\n",
"name": "stdout", "step 130/938 - loss: 1.6223 - acc_top1: 0.6851 - acc_top2: 0.7601 - 13ms/step\n",
"output_type": "stream", "step 140/938 - loss: 1.5622 - acc_top1: 0.6970 - acc_top2: 0.7694 - 13ms/step\n",
"text": [ "step 150/938 - loss: 1.6455 - acc_top1: 0.7065 - acc_top2: 0.7767 - 14ms/step\n",
"step 20/938 - loss: 2.0185 - acc_top1: 0.3656 - acc_top2: 0.4328 - 17ms/step\n", "step 160/938 - loss: 1.6127 - acc_top1: 0.7154 - acc_top2: 0.7837 - 14ms/step\n",
"step 30/938 - loss: 1.9579 - acc_top1: 0.4120 - acc_top2: 0.4969 - 16ms/step\n", "step 170/938 - loss: 1.5963 - acc_top1: 0.7242 - acc_top2: 0.7898 - 14ms/step\n",
"step 40/938 - loss: 1.8549 - acc_top1: 0.4602 - acc_top2: 0.5500 - 16ms/step\n", "step 180/938 - loss: 1.6485 - acc_top1: 0.7310 - acc_top2: 0.7948 - 14ms/step\n",
"step 50/938 - loss: 1.8628 - acc_top1: 0.5097 - acc_top2: 0.6028 - 16ms/step\n", "step 190/938 - loss: 1.6666 - acc_top1: 0.7368 - acc_top2: 0.7992 - 14ms/step\n",
"step 60/938 - loss: 1.7139 - acc_top1: 0.5456 - acc_top2: 0.6409 - 16ms/step\n", "step 200/938 - loss: 1.7862 - acc_top1: 0.7419 - acc_top2: 0.8030 - 14ms/step\n",
"step 70/938 - loss: 1.7296 - acc_top1: 0.5795 - acc_top2: 0.6719 - 15ms/step\n", "step 210/938 - loss: 1.5479 - acc_top1: 0.7464 - acc_top2: 0.8064 - 14ms/step\n",
"step 80/938 - loss: 1.6302 - acc_top1: 0.6053 - acc_top2: 0.6949 - 15ms/step\n", "step 220/938 - loss: 1.5650 - acc_top1: 0.7515 - acc_top2: 0.8106 - 14ms/step\n",
"step 90/938 - loss: 1.6688 - acc_top1: 0.6290 - acc_top2: 0.7158 - 15ms/step\n", "step 230/938 - loss: 1.5822 - acc_top1: 0.7562 - acc_top2: 0.8141 - 14ms/step\n",
"step 100/938 - loss: 1.6401 - acc_top1: 0.6491 - acc_top2: 0.7327 - 15ms/step\n", "step 240/938 - loss: 1.5966 - acc_top1: 0.7608 - acc_top2: 0.8179 - 14ms/step\n",
"step 110/938 - loss: 1.6357 - acc_top1: 0.6636 - acc_top2: 0.7440 - 15ms/step\n", "step 250/938 - loss: 1.5551 - acc_top1: 0.7650 - acc_top2: 0.8213 - 14ms/step\n",
"step 120/938 - loss: 1.6309 - acc_top1: 0.6767 - acc_top2: 0.7539 - 15ms/step\n", "step 260/938 - loss: 1.5584 - acc_top1: 0.7699 - acc_top2: 0.8249 - 14ms/step\n",
"step 130/938 - loss: 1.6445 - acc_top1: 0.6894 - acc_top2: 0.7638 - 15ms/step\n", "step 270/938 - loss: 1.5933 - acc_top1: 0.7730 - acc_top2: 0.8273 - 14ms/step\n",
"step 140/938 - loss: 1.5961 - acc_top1: 0.7002 - acc_top2: 0.7728 - 15ms/step\n", "step 280/938 - loss: 1.5589 - acc_top1: 0.7769 - acc_top2: 0.8301 - 14ms/step\n",
"step 150/938 - loss: 1.6822 - acc_top1: 0.7086 - acc_top2: 0.7794 - 15ms/step\n", "step 290/938 - loss: 1.6513 - acc_top1: 0.7793 - acc_top2: 0.8315 - 14ms/step\n",
"step 160/938 - loss: 1.6243 - acc_top1: 0.7176 - acc_top2: 0.7858 - 15ms/step\n", "step 300/938 - loss: 1.5929 - acc_top1: 0.7821 - acc_top2: 0.8337 - 14ms/step\n",
"step 170/938 - loss: 1.6159 - acc_top1: 0.7254 - acc_top2: 0.7915 - 15ms/step\n", "step 310/938 - loss: 1.5672 - acc_top1: 0.7849 - acc_top2: 0.8360 - 14ms/step\n",
"step 180/938 - loss: 1.6820 - acc_top1: 0.7312 - acc_top2: 0.7962 - 15ms/step\n", "step 320/938 - loss: 1.5147 - acc_top1: 0.7879 - acc_top2: 0.8381 - 14ms/step\n",
"step 190/938 - loss: 1.6733 - acc_top1: 0.7363 - acc_top2: 0.7999 - 15ms/step\n", "step 330/938 - loss: 1.5697 - acc_top1: 0.7902 - acc_top2: 0.8397 - 14ms/step\n",
"step 200/938 - loss: 1.7717 - acc_top1: 0.7413 - acc_top2: 0.8039 - 15ms/step\n", "step 340/938 - loss: 1.5697 - acc_top1: 0.7919 - acc_top2: 0.8406 - 14ms/step\n",
"step 210/938 - loss: 1.5468 - acc_top1: 0.7458 - acc_top2: 0.8072 - 15ms/step\n", "step 350/938 - loss: 1.6122 - acc_top1: 0.7941 - acc_top2: 0.8423 - 14ms/step\n",
"step 220/938 - loss: 1.5654 - acc_top1: 0.7506 - acc_top2: 0.8111 - 15ms/step\n", "step 360/938 - loss: 1.5934 - acc_top1: 0.7960 - acc_top2: 0.8435 - 14ms/step\n",
"step 230/938 - loss: 1.6129 - acc_top1: 0.7547 - acc_top2: 0.8143 - 15ms/step\n", "step 370/938 - loss: 1.6258 - acc_top1: 0.7982 - acc_top2: 0.8451 - 14ms/step\n",
"step 240/938 - loss: 1.5937 - acc_top1: 0.7592 - acc_top2: 0.8180 - 15ms/step\n", "step 380/938 - loss: 1.6805 - acc_top1: 0.7996 - acc_top2: 0.8463 - 14ms/step\n",
"step 250/938 - loss: 1.5457 - acc_top1: 0.7631 - acc_top2: 0.8214 - 15ms/step\n", "step 390/938 - loss: 1.5997 - acc_top1: 0.8011 - acc_top2: 0.8475 - 14ms/step\n",
"step 260/938 - loss: 1.6041 - acc_top1: 0.7673 - acc_top2: 0.8249 - 15ms/step\n", "step 400/938 - loss: 1.6151 - acc_top1: 0.8029 - acc_top2: 0.8488 - 14ms/step\n",
"step 270/938 - loss: 1.6049 - acc_top1: 0.7700 - acc_top2: 0.8271 - 15ms/step\n", "step 410/938 - loss: 1.5800 - acc_top1: 0.8047 - acc_top2: 0.8499 - 14ms/step\n",
"step 280/938 - loss: 1.5989 - acc_top1: 0.7735 - acc_top2: 0.8299 - 15ms/step\n", "step 420/938 - loss: 1.5950 - acc_top1: 0.8060 - acc_top2: 0.8508 - 14ms/step\n",
"step 290/938 - loss: 1.6950 - acc_top1: 0.7752 - acc_top2: 0.8310 - 15ms/step\n", "step 430/938 - loss: 1.5533 - acc_top1: 0.8075 - acc_top2: 0.8517 - 14ms/step\n",
"step 300/938 - loss: 1.5888 - acc_top1: 0.7781 - acc_top2: 0.8330 - 15ms/step\n", "step 440/938 - loss: 1.6171 - acc_top1: 0.8086 - acc_top2: 0.8521 - 14ms/step\n",
"step 310/938 - loss: 1.5983 - acc_top1: 0.7808 - acc_top2: 0.8350 - 15ms/step\n", "step 450/938 - loss: 1.5756 - acc_top1: 0.8103 - acc_top2: 0.8533 - 14ms/step\n",
"step 320/938 - loss: 1.5133 - acc_top1: 0.7840 - acc_top2: 0.8370 - 15ms/step\n", "step 460/938 - loss: 1.5655 - acc_top1: 0.8121 - acc_top2: 0.8544 - 14ms/step\n",
"step 330/938 - loss: 1.5587 - acc_top1: 0.7866 - acc_top2: 0.8385 - 15ms/step\n", "step 470/938 - loss: 1.5816 - acc_top1: 0.8139 - acc_top2: 0.8555 - 14ms/step\n",
"step 340/938 - loss: 1.6093 - acc_top1: 0.7882 - acc_top2: 0.8393 - 15ms/step\n", "step 480/938 - loss: 1.6202 - acc_top1: 0.8148 - acc_top2: 0.8562 - 14ms/step\n",
"step 350/938 - loss: 1.6259 - acc_top1: 0.7902 - acc_top2: 0.8410 - 15ms/step\n", "step 490/938 - loss: 1.6223 - acc_top1: 0.8157 - acc_top2: 0.8567 - 14ms/step\n",
"step 360/938 - loss: 1.6194 - acc_top1: 0.7918 - acc_top2: 0.8422 - 15ms/step\n", "step 500/938 - loss: 1.5198 - acc_top1: 0.8167 - acc_top2: 0.8574 - 14ms/step\n",
"step 370/938 - loss: 1.6531 - acc_top1: 0.7941 - acc_top2: 0.8438 - 15ms/step\n", "step 510/938 - loss: 1.5853 - acc_top1: 0.8181 - acc_top2: 0.8583 - 14ms/step\n",
"step 380/938 - loss: 1.6986 - acc_top1: 0.7957 - acc_top2: 0.8447 - 15ms/step\n", "step 520/938 - loss: 1.5252 - acc_top1: 0.8196 - acc_top2: 0.8593 - 14ms/step\n",
"step 390/938 - loss: 1.5932 - acc_top1: 0.7974 - acc_top2: 0.8459 - 15ms/step\n", "step 530/938 - loss: 1.5265 - acc_top1: 0.8207 - acc_top2: 0.8601 - 14ms/step\n",
"step 400/938 - loss: 1.6512 - acc_top1: 0.7993 - acc_top2: 0.8474 - 15ms/step\n", "step 540/938 - loss: 1.5297 - acc_top1: 0.8217 - acc_top2: 0.8608 - 14ms/step\n",
"step 410/938 - loss: 1.5698 - acc_top1: 0.8012 - acc_top2: 0.8487 - 15ms/step\n", "step 550/938 - loss: 1.5743 - acc_top1: 0.8226 - acc_top2: 0.8613 - 13ms/step\n",
"step 420/938 - loss: 1.5889 - acc_top1: 0.8025 - acc_top2: 0.8494 - 15ms/step\n", "step 560/938 - loss: 1.6419 - acc_top1: 0.8237 - acc_top2: 0.8622 - 13ms/step\n",
"step 430/938 - loss: 1.5518 - acc_top1: 0.8036 - acc_top2: 0.8503 - 15ms/step\n", "step 570/938 - loss: 1.5556 - acc_top1: 0.8247 - acc_top2: 0.8630 - 13ms/step\n",
"step 440/938 - loss: 1.6057 - acc_top1: 0.8048 - acc_top2: 0.8508 - 15ms/step\n", "step 580/938 - loss: 1.5349 - acc_top1: 0.8254 - acc_top2: 0.8635 - 13ms/step\n",
"step 450/938 - loss: 1.6081 - acc_top1: 0.8064 - acc_top2: 0.8519 - 15ms/step\n", "step 590/938 - loss: 1.4915 - acc_top1: 0.8263 - acc_top2: 0.8640 - 13ms/step\n",
"step 460/938 - loss: 1.5742 - acc_top1: 0.8079 - acc_top2: 0.8531 - 15ms/step\n", "step 600/938 - loss: 1.5672 - acc_top1: 0.8277 - acc_top2: 0.8651 - 13ms/step\n",
"step 470/938 - loss: 1.5704 - acc_top1: 0.8095 - acc_top2: 0.8543 - 15ms/step\n", "step 610/938 - loss: 1.5464 - acc_top1: 0.8288 - acc_top2: 0.8659 - 13ms/step\n",
"step 480/938 - loss: 1.6083 - acc_top1: 0.8110 - acc_top2: 0.8550 - 15ms/step\n", "step 620/938 - loss: 1.6329 - acc_top1: 0.8292 - acc_top2: 0.8661 - 13ms/step\n",
"step 490/938 - loss: 1.6081 - acc_top1: 0.8120 - acc_top2: 0.8555 - 15ms/step\n", "step 630/938 - loss: 1.6121 - acc_top1: 0.8296 - acc_top2: 0.8662 - 13ms/step\n",
"step 500/938 - loss: 1.5156 - acc_top1: 0.8133 - acc_top2: 0.8564 - 15ms/step\n", "step 640/938 - loss: 1.5636 - acc_top1: 0.8305 - acc_top2: 0.8668 - 13ms/step\n",
"step 510/938 - loss: 1.5856 - acc_top1: 0.8148 - acc_top2: 0.8573 - 15ms/step\n", "step 650/938 - loss: 1.6227 - acc_top1: 0.8311 - acc_top2: 0.8672 - 13ms/step\n",
"step 520/938 - loss: 1.5275 - acc_top1: 0.8163 - acc_top2: 0.8582 - 15ms/step\n", "step 660/938 - loss: 1.5646 - acc_top1: 0.8319 - acc_top2: 0.8678 - 13ms/step\n",
"step 530/938 - loss: 1.5345 - acc_top1: 0.8172 - acc_top2: 0.8591 - 15ms/step\n", "step 670/938 - loss: 1.5620 - acc_top1: 0.8325 - acc_top2: 0.8681 - 13ms/step\n",
"step 540/938 - loss: 1.5387 - acc_top1: 0.8181 - acc_top2: 0.8596 - 15ms/step\n", "step 680/938 - loss: 1.4908 - acc_top1: 0.8333 - acc_top2: 0.8688 - 13ms/step\n",
"step 550/938 - loss: 1.5753 - acc_top1: 0.8190 - acc_top2: 0.8601 - 15ms/step\n", "step 690/938 - loss: 1.6010 - acc_top1: 0.8339 - acc_top2: 0.8691 - 13ms/step\n",
"step 560/938 - loss: 1.6103 - acc_top1: 0.8203 - acc_top2: 0.8610 - 15ms/step\n", "step 700/938 - loss: 1.5592 - acc_top1: 0.8346 - acc_top2: 0.8695 - 13ms/step\n",
"step 570/938 - loss: 1.5571 - acc_top1: 0.8215 - acc_top2: 0.8618 - 15ms/step\n", "step 710/938 - loss: 1.6226 - acc_top1: 0.8352 - acc_top2: 0.8699 - 13ms/step\n",
"step 580/938 - loss: 1.5575 - acc_top1: 0.8221 - acc_top2: 0.8622 - 15ms/step\n", "step 720/938 - loss: 1.5642 - acc_top1: 0.8362 - acc_top2: 0.8705 - 13ms/step\n",
"step 590/938 - loss: 1.4821 - acc_top1: 0.8230 - acc_top2: 0.8627 - 15ms/step\n", "step 730/938 - loss: 1.5807 - acc_top1: 0.8367 - acc_top2: 0.8707 - 13ms/step\n",
"step 600/938 - loss: 1.5644 - acc_top1: 0.8243 - acc_top2: 0.8636 - 15ms/step\n", "step 740/938 - loss: 1.5721 - acc_top1: 0.8371 - acc_top2: 0.8708 - 13ms/step\n",
"step 610/938 - loss: 1.5317 - acc_top1: 0.8253 - acc_top2: 0.8644 - 15ms/step\n", "step 750/938 - loss: 1.6542 - acc_top1: 0.8377 - acc_top2: 0.8711 - 13ms/step\n",
"step 620/938 - loss: 1.5849 - acc_top1: 0.8258 - acc_top2: 0.8647 - 15ms/step\n", "step 760/938 - loss: 1.5128 - acc_top1: 0.8385 - acc_top2: 0.8716 - 13ms/step\n",
"step 630/938 - loss: 1.6087 - acc_top1: 0.8263 - acc_top2: 0.8649 - 15ms/step\n", "step 770/938 - loss: 1.5711 - acc_top1: 0.8391 - acc_top2: 0.8721 - 14ms/step\n",
"step 640/938 - loss: 1.5617 - acc_top1: 0.8272 - acc_top2: 0.8655 - 15ms/step\n", "step 780/938 - loss: 1.6095 - acc_top1: 0.8395 - acc_top2: 0.8725 - 14ms/step\n",
"step 650/938 - loss: 1.6376 - acc_top1: 0.8279 - acc_top2: 0.8660 - 15ms/step\n", "step 790/938 - loss: 1.5348 - acc_top1: 0.8402 - acc_top2: 0.8730 - 14ms/step\n",
"step 660/938 - loss: 1.5428 - acc_top1: 0.8287 - acc_top2: 0.8665 - 15ms/step\n", "step 800/938 - loss: 1.5715 - acc_top1: 0.8407 - acc_top2: 0.8732 - 14ms/step\n",
"step 670/938 - loss: 1.5797 - acc_top1: 0.8293 - acc_top2: 0.8668 - 15ms/step\n", "step 810/938 - loss: 1.5880 - acc_top1: 0.8413 - acc_top2: 0.8737 - 14ms/step\n",
"step 680/938 - loss: 1.5210 - acc_top1: 0.8300 - acc_top2: 0.8674 - 15ms/step\n", "step 820/938 - loss: 1.6160 - acc_top1: 0.8418 - acc_top2: 0.8740 - 14ms/step\n",
"step 690/938 - loss: 1.6159 - acc_top1: 0.8305 - acc_top2: 0.8677 - 15ms/step\n", "step 830/938 - loss: 1.5585 - acc_top1: 0.8426 - acc_top2: 0.8746 - 14ms/step\n",
"step 700/938 - loss: 1.5592 - acc_top1: 0.8313 - acc_top2: 0.8682 - 15ms/step\n", "step 840/938 - loss: 1.5829 - acc_top1: 0.8429 - acc_top2: 0.8748 - 14ms/step\n",
"step 710/938 - loss: 1.6400 - acc_top1: 0.8318 - acc_top2: 0.8685 - 15ms/step\n", "step 850/938 - loss: 1.5348 - acc_top1: 0.8435 - acc_top2: 0.8753 - 14ms/step\n",
"step 720/938 - loss: 1.5638 - acc_top1: 0.8327 - acc_top2: 0.8691 - 15ms/step\n", "step 860/938 - loss: 1.5448 - acc_top1: 0.8438 - acc_top2: 0.8754 - 14ms/step\n",
"step 730/938 - loss: 1.5691 - acc_top1: 0.8333 - acc_top2: 0.8693 - 15ms/step\n", "step 870/938 - loss: 1.5463 - acc_top1: 0.8443 - acc_top2: 0.8759 - 14ms/step\n",
"step 740/938 - loss: 1.5848 - acc_top1: 0.8337 - acc_top2: 0.8695 - 15ms/step\n", "step 880/938 - loss: 1.5763 - acc_top1: 0.8449 - acc_top2: 0.8762 - 14ms/step\n",
"step 750/938 - loss: 1.6317 - acc_top1: 0.8344 - acc_top2: 0.8698 - 15ms/step\n", "step 890/938 - loss: 1.5699 - acc_top1: 0.8453 - acc_top2: 0.8764 - 14ms/step\n",
"step 760/938 - loss: 1.5127 - acc_top1: 0.8352 - acc_top2: 0.8703 - 15ms/step\n", "step 900/938 - loss: 1.5616 - acc_top1: 0.8456 - acc_top2: 0.8766 - 14ms/step\n",
"step 770/938 - loss: 1.5822 - acc_top1: 0.8359 - acc_top2: 0.8707 - 15ms/step\n", "step 910/938 - loss: 1.5026 - acc_top1: 0.8461 - acc_top2: 0.8771 - 14ms/step\n",
"step 780/938 - loss: 1.6010 - acc_top1: 0.8366 - acc_top2: 0.8712 - 15ms/step\n", "step 920/938 - loss: 1.5380 - acc_top1: 0.8467 - acc_top2: 0.8774 - 14ms/step\n",
"step 790/938 - loss: 1.5238 - acc_top1: 0.8373 - acc_top2: 0.8717 - 15ms/step\n", "step 930/938 - loss: 1.5993 - acc_top1: 0.8470 - acc_top2: 0.8777 - 14ms/step\n",
"step 800/938 - loss: 1.5858 - acc_top1: 0.8377 - acc_top2: 0.8719 - 15ms/step\n", "step 938/938 - loss: 1.4942 - acc_top1: 0.8473 - acc_top2: 0.8778 - 14ms/step\n",
"step 810/938 - loss: 1.5800 - acc_top1: 0.8384 - acc_top2: 0.8724 - 15ms/step\n",
"step 820/938 - loss: 1.6312 - acc_top1: 0.8390 - acc_top2: 0.8727 - 15ms/step\n",
"step 830/938 - loss: 1.5812 - acc_top1: 0.8398 - acc_top2: 0.8732 - 15ms/step\n",
"step 840/938 - loss: 1.5661 - acc_top1: 0.8402 - acc_top2: 0.8734 - 15ms/step\n",
"step 850/938 - loss: 1.5379 - acc_top1: 0.8409 - acc_top2: 0.8739 - 15ms/step\n",
"step 860/938 - loss: 1.5266 - acc_top1: 0.8413 - acc_top2: 0.8740 - 15ms/step\n",
"step 870/938 - loss: 1.5264 - acc_top1: 0.8420 - acc_top2: 0.8745 - 15ms/step\n",
"step 880/938 - loss: 1.5688 - acc_top1: 0.8425 - acc_top2: 0.8748 - 15ms/step\n",
"step 890/938 - loss: 1.5707 - acc_top1: 0.8429 - acc_top2: 0.8751 - 15ms/step\n",
"step 900/938 - loss: 1.5564 - acc_top1: 0.8432 - acc_top2: 0.8752 - 15ms/step\n",
"step 910/938 - loss: 1.4924 - acc_top1: 0.8438 - acc_top2: 0.8757 - 15ms/step\n",
"step 920/938 - loss: 1.5514 - acc_top1: 0.8443 - acc_top2: 0.8760 - 15ms/step\n",
"step 930/938 - loss: 1.5850 - acc_top1: 0.8446 - acc_top2: 0.8762 - 15ms/step\n",
"step 938/938 - loss: 1.4915 - acc_top1: 0.8448 - acc_top2: 0.8764 - 15ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/0\n", "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/0\n",
"Eval begin...\n",
"step 10/157 - loss: 1.5984 - acc_top1: 0.8797 - acc_top2: 0.8953 - 5ms/step\n",
"step 20/157 - loss: 1.6266 - acc_top1: 0.8789 - acc_top2: 0.9000 - 5ms/step\n",
"step 30/157 - loss: 1.6475 - acc_top1: 0.8771 - acc_top2: 0.8984 - 5ms/step\n",
"step 40/157 - loss: 1.6329 - acc_top1: 0.8730 - acc_top2: 0.8957 - 5ms/step\n",
"step 50/157 - loss: 1.5399 - acc_top1: 0.8712 - acc_top2: 0.8934 - 5ms/step\n",
"step 60/157 - loss: 1.6322 - acc_top1: 0.8750 - acc_top2: 0.8961 - 5ms/step\n",
"step 70/157 - loss: 1.5818 - acc_top1: 0.8721 - acc_top2: 0.8931 - 5ms/step\n",
"step 80/157 - loss: 1.5522 - acc_top1: 0.8760 - acc_top2: 0.8979 - 5ms/step\n",
"step 90/157 - loss: 1.6085 - acc_top1: 0.8785 - acc_top2: 0.8984 - 5ms/step\n",
"step 100/157 - loss: 1.5661 - acc_top1: 0.8784 - acc_top2: 0.8980 - 5ms/step\n",
"step 110/157 - loss: 1.5694 - acc_top1: 0.8805 - acc_top2: 0.8996 - 5ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 120/157 - loss: 1.6012 - acc_top1: 0.8824 - acc_top2: 0.9003 - 5ms/step\n",
"step 130/157 - loss: 1.5378 - acc_top1: 0.8844 - acc_top2: 0.9017 - 5ms/step\n",
"step 140/157 - loss: 1.5068 - acc_top1: 0.8858 - acc_top2: 0.9022 - 5ms/step\n",
"step 150/157 - loss: 1.5424 - acc_top1: 0.8873 - acc_top2: 0.9029 - 5ms/step\n",
"step 157/157 - loss: 1.5862 - acc_top1: 0.8872 - acc_top2: 0.9035 - 5ms/step\n",
"Eval samples: 10000\n",
"Epoch 2/2\n", "Epoch 2/2\n",
"step 10/938 - loss: 1.5988 - acc_top1: 0.8859 - acc_top2: 0.9016 - 15ms/step\n", "step 10/938 - loss: 1.5919 - acc_top1: 0.8875 - acc_top2: 0.9047 - 14ms/step\n",
"step 20/938 - loss: 1.5702 - acc_top1: 0.8852 - acc_top2: 0.9047 - 15ms/step\n", "step 20/938 - loss: 1.5900 - acc_top1: 0.8875 - acc_top2: 0.9062 - 14ms/step\n",
"step 30/938 - loss: 1.5999 - acc_top1: 0.8833 - acc_top2: 0.9021 - 15ms/step\n", "step 30/938 - loss: 1.5929 - acc_top1: 0.8891 - acc_top2: 0.9036 - 13ms/step\n",
"step 40/938 - loss: 1.5652 - acc_top1: 0.8816 - acc_top2: 0.9000 - 15ms/step\n", "step 40/938 - loss: 1.5855 - acc_top1: 0.8883 - acc_top2: 0.9027 - 13ms/step\n",
"step 50/938 - loss: 1.6163 - acc_top1: 0.8853 - acc_top2: 0.9047 - 15ms/step\n", "step 50/938 - loss: 1.6197 - acc_top1: 0.8916 - acc_top2: 0.9072 - 13ms/step\n",
"step 60/938 - loss: 1.5307 - acc_top1: 0.8849 - acc_top2: 0.9049 - 15ms/step\n", "step 60/938 - loss: 1.5084 - acc_top1: 0.8914 - acc_top2: 0.9078 - 13ms/step\n",
"step 70/938 - loss: 1.5542 - acc_top1: 0.8846 - acc_top2: 0.9029 - 15ms/step\n", "step 70/938 - loss: 1.5552 - acc_top1: 0.8904 - acc_top2: 0.9067 - 13ms/step\n",
"step 80/938 - loss: 1.5694 - acc_top1: 0.8816 - acc_top2: 0.9008 - 15ms/step\n", "step 80/938 - loss: 1.5700 - acc_top1: 0.8887 - acc_top2: 0.9049 - 13ms/step\n",
"step 90/938 - loss: 1.6030 - acc_top1: 0.8806 - acc_top2: 0.8991 - 15ms/step\n", "step 90/938 - loss: 1.6073 - acc_top1: 0.8866 - acc_top2: 0.9030 - 13ms/step\n",
"step 100/938 - loss: 1.5631 - acc_top1: 0.8814 - acc_top2: 0.8989 - 15ms/step\n", "step 100/938 - loss: 1.5754 - acc_top1: 0.8859 - acc_top2: 0.9022 - 13ms/step\n"
"step 110/938 - loss: 1.5598 - acc_top1: 0.8804 - acc_top2: 0.8984 - 15ms/step\n",
"step 120/938 - loss: 1.5773 - acc_top1: 0.8803 - acc_top2: 0.8986 - 15ms/step\n",
"step 130/938 - loss: 1.5076 - acc_top1: 0.8815 - acc_top2: 0.8995 - 15ms/step\n",
"step 140/938 - loss: 1.6064 - acc_top1: 0.8809 - acc_top2: 0.8988 - 15ms/step\n",
"step 150/938 - loss: 1.5279 - acc_top1: 0.8815 - acc_top2: 0.8993 - 15ms/step\n",
"step 160/938 - loss: 1.6039 - acc_top1: 0.8820 - acc_top2: 0.8998 - 15ms/step\n",
"step 170/938 - loss: 1.5709 - acc_top1: 0.8814 - acc_top2: 0.8993 - 15ms/step\n",
"step 180/938 - loss: 1.6164 - acc_top1: 0.8806 - acc_top2: 0.8985 - 15ms/step\n",
"step 190/938 - loss: 1.5920 - acc_top1: 0.8802 - acc_top2: 0.8985 - 15ms/step\n",
"step 200/938 - loss: 1.6457 - acc_top1: 0.8793 - acc_top2: 0.8973 - 15ms/step\n",
"step 210/938 - loss: 1.6045 - acc_top1: 0.8794 - acc_top2: 0.8977 - 15ms/step\n",
"step 220/938 - loss: 1.6614 - acc_top1: 0.8795 - acc_top2: 0.8975 - 15ms/step\n",
"step 230/938 - loss: 1.5384 - acc_top1: 0.8789 - acc_top2: 0.8966 - 15ms/step\n",
"step 240/938 - loss: 1.5556 - acc_top1: 0.8785 - acc_top2: 0.8960 - 15ms/step\n",
"step 250/938 - loss: 1.6006 - acc_top1: 0.8782 - acc_top2: 0.8961 - 15ms/step\n",
"step 260/938 - loss: 1.5552 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n",
"step 270/938 - loss: 1.5805 - acc_top1: 0.8791 - acc_top2: 0.8970 - 15ms/step\n",
"step 280/938 - loss: 1.5404 - acc_top1: 0.8787 - acc_top2: 0.8966 - 15ms/step\n",
"step 290/938 - loss: 1.6023 - acc_top1: 0.8789 - acc_top2: 0.8969 - 15ms/step\n",
"step 300/938 - loss: 1.5706 - acc_top1: 0.8788 - acc_top2: 0.8969 - 15ms/step\n",
"step 310/938 - loss: 1.5424 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n",
"step 320/938 - loss: 1.5823 - acc_top1: 0.8798 - acc_top2: 0.8975 - 15ms/step\n",
"step 330/938 - loss: 1.5600 - acc_top1: 0.8801 - acc_top2: 0.8977 - 15ms/step\n",
"step 340/938 - loss: 1.6258 - acc_top1: 0.8795 - acc_top2: 0.8970 - 15ms/step\n",
"step 350/938 - loss: 1.5093 - acc_top1: 0.8796 - acc_top2: 0.8972 - 15ms/step\n",
"step 360/938 - loss: 1.6030 - acc_top1: 0.8794 - acc_top2: 0.8967 - 15ms/step\n",
"step 370/938 - loss: 1.5732 - acc_top1: 0.8795 - acc_top2: 0.8969 - 15ms/step\n",
"step 380/938 - loss: 1.5980 - acc_top1: 0.8797 - acc_top2: 0.8972 - 15ms/step\n",
"step 390/938 - loss: 1.5902 - acc_top1: 0.8800 - acc_top2: 0.8974 - 15ms/step\n",
"step 400/938 - loss: 1.5395 - acc_top1: 0.8809 - acc_top2: 0.8983 - 15ms/step\n",
"step 410/938 - loss: 1.6623 - acc_top1: 0.8804 - acc_top2: 0.8978 - 15ms/step\n",
"step 420/938 - loss: 1.4987 - acc_top1: 0.8810 - acc_top2: 0.8983 - 15ms/step\n",
"step 430/938 - loss: 1.5989 - acc_top1: 0.8811 - acc_top2: 0.8983 - 15ms/step\n",
"step 440/938 - loss: 1.5722 - acc_top1: 0.8813 - acc_top2: 0.8984 - 15ms/step\n",
"step 450/938 - loss: 1.5549 - acc_top1: 0.8818 - acc_top2: 0.8986 - 15ms/step\n",
"step 460/938 - loss: 1.5536 - acc_top1: 0.8819 - acc_top2: 0.8986 - 15ms/step\n",
"step 470/938 - loss: 1.5247 - acc_top1: 0.8826 - acc_top2: 0.8992 - 15ms/step\n",
"step 480/938 - loss: 1.5520 - acc_top1: 0.8830 - acc_top2: 0.8995 - 15ms/step\n",
"step 490/938 - loss: 1.5518 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n",
"step 500/938 - loss: 1.5227 - acc_top1: 0.8837 - acc_top2: 0.9000 - 15ms/step\n",
"step 510/938 - loss: 1.6014 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n",
"step 520/938 - loss: 1.5526 - acc_top1: 0.8834 - acc_top2: 0.8998 - 15ms/step\n",
"step 530/938 - loss: 1.5849 - acc_top1: 0.8838 - acc_top2: 0.9001 - 15ms/step\n",
"step 540/938 - loss: 1.5607 - acc_top1: 0.8840 - acc_top2: 0.9006 - 15ms/step\n",
"step 550/938 - loss: 1.6438 - acc_top1: 0.8843 - acc_top2: 0.9010 - 15ms/step\n",
"step 560/938 - loss: 1.5229 - acc_top1: 0.8848 - acc_top2: 0.9014 - 15ms/step\n",
"step 570/938 - loss: 1.5395 - acc_top1: 0.8846 - acc_top2: 0.9012 - 15ms/step\n",
"step 580/938 - loss: 1.5409 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n",
"step 590/938 - loss: 1.5851 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n",
"step 600/938 - loss: 1.5383 - acc_top1: 0.8849 - acc_top2: 0.9013 - 15ms/step\n",
"step 610/938 - loss: 1.5969 - acc_top1: 0.8853 - acc_top2: 0.9016 - 15ms/step\n",
"step 620/938 - loss: 1.5634 - acc_top1: 0.8854 - acc_top2: 0.9017 - 15ms/step\n",
"step 630/938 - loss: 1.6308 - acc_top1: 0.8857 - acc_top2: 0.9019 - 15ms/step\n",
"step 640/938 - loss: 1.6413 - acc_top1: 0.8859 - acc_top2: 0.9021 - 15ms/step\n",
"step 650/938 - loss: 1.5954 - acc_top1: 0.8856 - acc_top2: 0.9020 - 15ms/step\n",
"step 660/938 - loss: 1.5278 - acc_top1: 0.8859 - acc_top2: 0.9023 - 15ms/step\n",
"step 670/938 - loss: 1.5144 - acc_top1: 0.8869 - acc_top2: 0.9035 - 15ms/step\n",
"step 680/938 - loss: 1.4612 - acc_top1: 0.8879 - acc_top2: 0.9048 - 15ms/step\n",
"step 690/938 - loss: 1.4820 - acc_top1: 0.8891 - acc_top2: 0.9060 - 15ms/step\n",
"step 700/938 - loss: 1.4766 - acc_top1: 0.8901 - acc_top2: 0.9073 - 15ms/step\n",
"step 710/938 - loss: 1.5245 - acc_top1: 0.8911 - acc_top2: 0.9083 - 15ms/step\n",
"step 720/938 - loss: 1.5183 - acc_top1: 0.8922 - acc_top2: 0.9095 - 15ms/step\n",
"step 730/938 - loss: 1.4971 - acc_top1: 0.8932 - acc_top2: 0.9106 - 15ms/step\n",
"step 740/938 - loss: 1.4744 - acc_top1: 0.8944 - acc_top2: 0.9117 - 15ms/step\n",
"step 750/938 - loss: 1.4789 - acc_top1: 0.8952 - acc_top2: 0.9127 - 15ms/step\n",
"step 760/938 - loss: 1.5114 - acc_top1: 0.8959 - acc_top2: 0.9137 - 15ms/step\n",
"step 770/938 - loss: 1.5035 - acc_top1: 0.8970 - acc_top2: 0.9147 - 15ms/step\n",
"step 780/938 - loss: 1.4668 - acc_top1: 0.8978 - acc_top2: 0.9157 - 15ms/step\n",
"step 790/938 - loss: 1.4850 - acc_top1: 0.8986 - acc_top2: 0.9166 - 15ms/step\n",
"step 800/938 - loss: 1.4777 - acc_top1: 0.8996 - acc_top2: 0.9176 - 15ms/step\n",
"step 810/938 - loss: 1.4783 - acc_top1: 0.9005 - acc_top2: 0.9186 - 15ms/step\n",
"step 820/938 - loss: 1.5256 - acc_top1: 0.9011 - acc_top2: 0.9194 - 15ms/step\n",
"step 830/938 - loss: 1.4801 - acc_top1: 0.9019 - acc_top2: 0.9202 - 15ms/step\n",
"step 840/938 - loss: 1.4873 - acc_top1: 0.9026 - acc_top2: 0.9211 - 15ms/step\n",
"step 850/938 - loss: 1.5093 - acc_top1: 0.9034 - acc_top2: 0.9219 - 15ms/step\n",
"step 860/938 - loss: 1.4727 - acc_top1: 0.9042 - acc_top2: 0.9227 - 15ms/step\n",
"step 870/938 - loss: 1.4917 - acc_top1: 0.9050 - acc_top2: 0.9235 - 15ms/step\n",
"step 880/938 - loss: 1.4792 - acc_top1: 0.9058 - acc_top2: 0.9243 - 15ms/step\n",
"step 890/938 - loss: 1.4854 - acc_top1: 0.9066 - acc_top2: 0.9251 - 15ms/step\n",
"step 900/938 - loss: 1.4616 - acc_top1: 0.9074 - acc_top2: 0.9258 - 15ms/step\n",
"step 910/938 - loss: 1.4954 - acc_top1: 0.9081 - acc_top2: 0.9265 - 15ms/step\n",
"step 920/938 - loss: 1.4875 - acc_top1: 0.9087 - acc_top2: 0.9272 - 15ms/step\n",
"step 930/938 - loss: 1.5037 - acc_top1: 0.9094 - acc_top2: 0.9279 - 15ms/step\n",
"step 938/938 - loss: 1.4964 - acc_top1: 0.9099 - acc_top2: 0.9284 - 15ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n",
"Eval begin...\n",
"step 10/157 - loss: 1.5196 - acc_top1: 0.9719 - acc_top2: 0.9969 - 5ms/step\n",
"step 20/157 - loss: 1.5393 - acc_top1: 0.9672 - acc_top2: 0.9945 - 6ms/step\n",
"step 30/157 - loss: 1.4928 - acc_top1: 0.9630 - acc_top2: 0.9906 - 5ms/step\n",
"step 40/157 - loss: 1.4765 - acc_top1: 0.9617 - acc_top2: 0.9902 - 5ms/step\n",
"step 50/157 - loss: 1.4646 - acc_top1: 0.9631 - acc_top2: 0.9903 - 5ms/step\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"step 60/157 - loss: 1.5646 - acc_top1: 0.9641 - acc_top2: 0.9906 - 5ms/step\n", "step 110/938 - loss: 1.5484 - acc_top1: 0.8848 - acc_top2: 0.9017 - 14ms/step\n",
"step 70/157 - loss: 1.5167 - acc_top1: 0.9618 - acc_top2: 0.9900 - 5ms/step\n", "step 120/938 - loss: 1.5904 - acc_top1: 0.8840 - acc_top2: 0.9020 - 14ms/step\n",
"step 80/157 - loss: 1.4728 - acc_top1: 0.9635 - acc_top2: 0.9906 - 5ms/step\n", "step 130/938 - loss: 1.5108 - acc_top1: 0.8852 - acc_top2: 0.9025 - 14ms/step\n",
"step 90/157 - loss: 1.5030 - acc_top1: 0.9668 - acc_top2: 0.9917 - 5ms/step\n", "step 140/938 - loss: 1.6199 - acc_top1: 0.8840 - acc_top2: 0.9016 - 14ms/step\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9677 - acc_top2: 0.9914 - 5ms/step\n", "step 150/938 - loss: 1.5337 - acc_top1: 0.8842 - acc_top2: 0.9019 - 13ms/step\n",
"step 110/157 - loss: 1.4612 - acc_top1: 0.9689 - acc_top2: 0.9913 - 5ms/step\n", "step 160/938 - loss: 1.6094 - acc_top1: 0.8846 - acc_top2: 0.9023 - 13ms/step\n",
"step 120/157 - loss: 1.4612 - acc_top1: 0.9707 - acc_top2: 0.9919 - 5ms/step\n", "step 170/938 - loss: 1.5653 - acc_top1: 0.8843 - acc_top2: 0.9019 - 13ms/step\n",
"step 130/157 - loss: 1.4621 - acc_top1: 0.9719 - acc_top2: 0.9923 - 5ms/step\n", "step 180/938 - loss: 1.5978 - acc_top1: 0.8835 - acc_top2: 0.9011 - 13ms/step\n",
"step 140/157 - loss: 1.4612 - acc_top1: 0.9734 - acc_top2: 0.9929 - 5ms/step\n", "step 190/938 - loss: 1.5950 - acc_top1: 0.8833 - acc_top2: 0.9012 - 13ms/step\n",
"step 150/157 - loss: 1.4660 - acc_top1: 0.9748 - acc_top2: 0.9933 - 5ms/step\n", "step 200/938 - loss: 1.6422 - acc_top1: 0.8828 - acc_top2: 0.9002 - 13ms/step\n",
"step 157/157 - loss: 1.5215 - acc_top1: 0.9731 - acc_top2: 0.9930 - 5ms/step\n", "step 210/938 - loss: 1.5752 - acc_top1: 0.8831 - acc_top2: 0.9004 - 13ms/step\n",
"Eval samples: 10000\n", "step 220/938 - loss: 1.6635 - acc_top1: 0.8832 - acc_top2: 0.9001 - 13ms/step\n",
"step 230/938 - loss: 1.5726 - acc_top1: 0.8823 - acc_top2: 0.8991 - 13ms/step\n",
"step 240/938 - loss: 1.5702 - acc_top1: 0.8814 - acc_top2: 0.8981 - 13ms/step\n",
"step 250/938 - loss: 1.5748 - acc_top1: 0.8814 - acc_top2: 0.8981 - 14ms/step\n",
"step 260/938 - loss: 1.5589 - acc_top1: 0.8822 - acc_top2: 0.8988 - 14ms/step\n",
"step 270/938 - loss: 1.5902 - acc_top1: 0.8823 - acc_top2: 0.8988 - 14ms/step\n",
"step 280/938 - loss: 1.5646 - acc_top1: 0.8817 - acc_top2: 0.8982 - 14ms/step\n",
"step 290/938 - loss: 1.6280 - acc_top1: 0.8819 - acc_top2: 0.8985 - 14ms/step\n",
"step 300/938 - loss: 1.5697 - acc_top1: 0.8815 - acc_top2: 0.8982 - 14ms/step\n",
"step 310/938 - loss: 1.5540 - acc_top1: 0.8814 - acc_top2: 0.8981 - 14ms/step\n",
"step 320/938 - loss: 1.5598 - acc_top1: 0.8821 - acc_top2: 0.8988 - 14ms/step\n",
"step 330/938 - loss: 1.5498 - acc_top1: 0.8824 - acc_top2: 0.8991 - 14ms/step\n",
"step 340/938 - loss: 1.6276 - acc_top1: 0.8818 - acc_top2: 0.8984 - 14ms/step\n",
"step 350/938 - loss: 1.5129 - acc_top1: 0.8821 - acc_top2: 0.8988 - 14ms/step\n",
"step 360/938 - loss: 1.6158 - acc_top1: 0.8818 - acc_top2: 0.8984 - 14ms/step\n",
"step 370/938 - loss: 1.5300 - acc_top1: 0.8820 - acc_top2: 0.8986 - 14ms/step\n",
"step 380/938 - loss: 1.5718 - acc_top1: 0.8822 - acc_top2: 0.8988 - 14ms/step\n",
"step 390/938 - loss: 1.5898 - acc_top1: 0.8825 - acc_top2: 0.8990 - 14ms/step\n",
"step 400/938 - loss: 1.5177 - acc_top1: 0.8834 - acc_top2: 0.9000 - 14ms/step\n",
"step 410/938 - loss: 1.6493 - acc_top1: 0.8831 - acc_top2: 0.8997 - 14ms/step\n",
"step 420/938 - loss: 1.5071 - acc_top1: 0.8838 - acc_top2: 0.9002 - 14ms/step\n",
"step 430/938 - loss: 1.5982 - acc_top1: 0.8840 - acc_top2: 0.9002 - 14ms/step\n",
"step 440/938 - loss: 1.5649 - acc_top1: 0.8841 - acc_top2: 0.9003 - 14ms/step\n",
"step 450/938 - loss: 1.5555 - acc_top1: 0.8844 - acc_top2: 0.9005 - 14ms/step\n",
"step 460/938 - loss: 1.5536 - acc_top1: 0.8845 - acc_top2: 0.9005 - 14ms/step\n",
"step 470/938 - loss: 1.5401 - acc_top1: 0.8851 - acc_top2: 0.9011 - 14ms/step\n",
"step 480/938 - loss: 1.5549 - acc_top1: 0.8854 - acc_top2: 0.9013 - 14ms/step\n",
"step 490/938 - loss: 1.5596 - acc_top1: 0.8858 - acc_top2: 0.9017 - 14ms/step\n",
"step 500/938 - loss: 1.5059 - acc_top1: 0.8860 - acc_top2: 0.9018 - 14ms/step\n",
"step 510/938 - loss: 1.6073 - acc_top1: 0.8858 - acc_top2: 0.9017 - 14ms/step\n",
"step 520/938 - loss: 1.5588 - acc_top1: 0.8857 - acc_top2: 0.9016 - 14ms/step\n",
"step 530/938 - loss: 1.6165 - acc_top1: 0.8859 - acc_top2: 0.9019 - 14ms/step\n",
"step 540/938 - loss: 1.5884 - acc_top1: 0.8862 - acc_top2: 0.9023 - 14ms/step\n",
"step 550/938 - loss: 1.6552 - acc_top1: 0.8863 - acc_top2: 0.9027 - 14ms/step\n",
"step 560/938 - loss: 1.5529 - acc_top1: 0.8867 - acc_top2: 0.9030 - 14ms/step\n",
"step 570/938 - loss: 1.5441 - acc_top1: 0.8866 - acc_top2: 0.9029 - 14ms/step\n",
"step 580/938 - loss: 1.5438 - acc_top1: 0.8867 - acc_top2: 0.9029 - 14ms/step\n",
"step 590/938 - loss: 1.5761 - acc_top1: 0.8868 - acc_top2: 0.9029 - 14ms/step\n",
"step 600/938 - loss: 1.5384 - acc_top1: 0.8867 - acc_top2: 0.9029 - 14ms/step\n",
"step 610/938 - loss: 1.5858 - acc_top1: 0.8871 - acc_top2: 0.9032 - 14ms/step\n",
"step 620/938 - loss: 1.5524 - acc_top1: 0.8872 - acc_top2: 0.9034 - 14ms/step\n",
"step 630/938 - loss: 1.6182 - acc_top1: 0.8875 - acc_top2: 0.9035 - 14ms/step\n",
"step 640/938 - loss: 1.6326 - acc_top1: 0.8877 - acc_top2: 0.9037 - 14ms/step\n",
"step 650/938 - loss: 1.5871 - acc_top1: 0.8877 - acc_top2: 0.9035 - 14ms/step\n",
"step 660/938 - loss: 1.5403 - acc_top1: 0.8877 - acc_top2: 0.9034 - 14ms/step\n",
"step 670/938 - loss: 1.5539 - acc_top1: 0.8879 - acc_top2: 0.9035 - 14ms/step\n",
"step 680/938 - loss: 1.4918 - acc_top1: 0.8881 - acc_top2: 0.9036 - 14ms/step\n",
"step 690/938 - loss: 1.6007 - acc_top1: 0.8882 - acc_top2: 0.9036 - 14ms/step\n",
"step 700/938 - loss: 1.5539 - acc_top1: 0.8883 - acc_top2: 0.9037 - 14ms/step\n",
"step 710/938 - loss: 1.6036 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n",
"step 720/938 - loss: 1.5943 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n",
"step 730/938 - loss: 1.5714 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n",
"step 740/938 - loss: 1.5095 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n",
"step 750/938 - loss: 1.5069 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n",
"step 760/938 - loss: 1.5816 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n",
"step 770/938 - loss: 1.5855 - acc_top1: 0.8880 - acc_top2: 0.9033 - 14ms/step\n",
"step 780/938 - loss: 1.5599 - acc_top1: 0.8881 - acc_top2: 0.9034 - 14ms/step\n",
"step 790/938 - loss: 1.6029 - acc_top1: 0.8879 - acc_top2: 0.9032 - 14ms/step\n",
"step 800/938 - loss: 1.5839 - acc_top1: 0.8880 - acc_top2: 0.9033 - 14ms/step\n",
"step 810/938 - loss: 1.5545 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n",
"step 820/938 - loss: 1.5458 - acc_top1: 0.8881 - acc_top2: 0.9036 - 14ms/step\n",
"step 830/938 - loss: 1.5911 - acc_top1: 0.8879 - acc_top2: 0.9033 - 14ms/step\n",
"step 840/938 - loss: 1.5845 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n",
"step 850/938 - loss: 1.5628 - acc_top1: 0.8880 - acc_top2: 0.9035 - 14ms/step\n",
"step 860/938 - loss: 1.5596 - acc_top1: 0.8880 - acc_top2: 0.9035 - 14ms/step\n",
"step 870/938 - loss: 1.5843 - acc_top1: 0.8882 - acc_top2: 0.9036 - 14ms/step\n",
"step 880/938 - loss: 1.5393 - acc_top1: 0.8883 - acc_top2: 0.9036 - 14ms/step\n",
"step 890/938 - loss: 1.5382 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n",
"step 900/938 - loss: 1.5910 - acc_top1: 0.8884 - acc_top2: 0.9036 - 14ms/step\n",
"step 910/938 - loss: 1.5682 - acc_top1: 0.8886 - acc_top2: 0.9038 - 14ms/step\n",
"step 920/938 - loss: 1.5736 - acc_top1: 0.8889 - acc_top2: 0.9039 - 14ms/step\n",
"step 930/938 - loss: 1.5283 - acc_top1: 0.8888 - acc_top2: 0.9038 - 14ms/step\n",
"step 938/938 - loss: 1.5582 - acc_top1: 0.8888 - acc_top2: 0.9038 - 14ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/final\n" "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/final\n"
] ]
} }
], ],
"source": [ "source": [
"model.fit(train_dataset,\n", "model.fit(train_dataset,\n",
" test_dataset,\n",
" epochs=2,\n", " epochs=2,\n",
" batch_size=64,\n", " batch_size=64,\n",
" save_dir='mnist_checkpoint')" " save_dir='mnist_checkpoint')"
...@@ -467,131 +537,59 @@ ...@@ -467,131 +537,59 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 训练方式1结束\n", "### 使用model.evaluate来预测模型"
"以上就是训练方式1,可以非常快速的完成网络模型训练。此外,paddle还可以用下面的方式来完成模型的训练。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.训练方式2\n",
"方式1可以快速便捷的完成训练,隐藏了训练时的细节。而方式2则可以用最基本的方式,完成模型的训练。具体如下。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 23, "execution_count": 43,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch: 0, batch_id: 0, loss is: [2.300888], acc is: [0.28125]\n", "Eval begin...\n",
"epoch: 0, batch_id: 100, loss is: [1.6948285], acc is: [0.8125]\n", "step 10/157 - loss: 1.5447 - acc_top1: 0.8953 - acc_top2: 0.9078 - 5ms/step\n",
"epoch: 0, batch_id: 200, loss is: [1.5282547], acc is: [0.96875]\n", "step 20/157 - loss: 1.6185 - acc_top1: 0.8930 - acc_top2: 0.9078 - 5ms/step\n",
"epoch: 0, batch_id: 300, loss is: [1.509404], acc is: [0.96875]\n", "step 30/157 - loss: 1.6497 - acc_top1: 0.8917 - acc_top2: 0.9057 - 5ms/step\n",
"epoch: 0, batch_id: 400, loss is: [1.4973292], acc is: [1.]\n", "step 40/157 - loss: 1.6318 - acc_top1: 0.8902 - acc_top2: 0.9055 - 5ms/step\n",
"epoch: 0, batch_id: 500, loss is: [1.5063374], acc is: [0.984375]\n", "step 50/157 - loss: 1.5533 - acc_top1: 0.8856 - acc_top2: 0.9012 - 5ms/step\n",
"epoch: 0, batch_id: 600, loss is: [1.490077], acc is: [0.984375]\n", "step 60/157 - loss: 1.6212 - acc_top1: 0.8878 - acc_top2: 0.9036 - 5ms/step\n",
"epoch: 0, batch_id: 700, loss is: [1.5206413], acc is: [0.984375]\n", "step 70/157 - loss: 1.5674 - acc_top1: 0.8839 - acc_top2: 0.9002 - 5ms/step\n",
"epoch: 0, batch_id: 800, loss is: [1.5104291], acc is: [1.]\n", "step 80/157 - loss: 1.5409 - acc_top1: 0.8891 - acc_top2: 0.9043 - 5ms/step\n",
"epoch: 0, batch_id: 900, loss is: [1.5216607], acc is: [0.96875]\n", "step 90/157 - loss: 1.6133 - acc_top1: 0.8903 - acc_top2: 0.9045 - 5ms/step\n",
"epoch: 1, batch_id: 0, loss is: [1.4949667], acc is: [0.984375]\n", "step 100/157 - loss: 1.5535 - acc_top1: 0.8909 - acc_top2: 0.9044 - 5ms/step\n",
"epoch: 1, batch_id: 100, loss is: [1.4923338], acc is: [0.96875]\n", "step 110/157 - loss: 1.5690 - acc_top1: 0.8916 - acc_top2: 0.9054 - 5ms/step\n",
"epoch: 1, batch_id: 200, loss is: [1.5026703], acc is: [1.]\n", "step 120/157 - loss: 1.6147 - acc_top1: 0.8926 - acc_top2: 0.9055 - 5ms/step\n",
"epoch: 1, batch_id: 300, loss is: [1.4965419], acc is: [0.984375]\n", "step 130/157 - loss: 1.5203 - acc_top1: 0.8944 - acc_top2: 0.9066 - 5ms/step\n",
"epoch: 1, batch_id: 400, loss is: [1.5270758], acc is: [1.]\n", "step 140/157 - loss: 1.5066 - acc_top1: 0.8952 - acc_top2: 0.9068 - 5ms/step\n",
"epoch: 1, batch_id: 500, loss is: [1.4774603], acc is: [1.]\n", "step 150/157 - loss: 1.5536 - acc_top1: 0.8958 - acc_top2: 0.9072 - 5ms/step\n",
"epoch: 1, batch_id: 600, loss is: [1.4762554], acc is: [0.984375]\n", "step 157/157 - loss: 1.5855 - acc_top1: 0.8956 - acc_top2: 0.9076 - 5ms/step\n",
"epoch: 1, batch_id: 700, loss is: [1.4773959], acc is: [0.984375]\n", "Eval samples: 10000\n"
"epoch: 1, batch_id: 800, loss is: [1.5044193], acc is: [1.]\n",
"epoch: 1, batch_id: 900, loss is: [1.4986757], acc is: [0.96875]\n"
] ]
} },
],
"source": [
"import paddle\n",
"train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"def train(model):\n",
" model.train()\n",
" epochs = 2\n",
" batch_size = 64\n",
" optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
" for epoch in range(epochs):\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
" optim.minimize(avg_loss)\n",
" model.clear_gradients()\n",
"model = LeNet()\n",
"train(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 对模型进行验证"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{ {
"name": "stdout", "data": {
"output_type": "stream", "text/plain": [
"text": [ "{'loss': [1.585474], 'acc_top1': 0.8956, 'acc_top2': 0.9076}"
"batch_id: 0, loss is: [1.5017498], acc is: [1.]\n", ]
"batch_id: 100, loss is: [1.4783669], acc is: [0.984375]\n", },
"batch_id: 200, loss is: [1.4958509], acc is: [1.]\n", "execution_count": 43,
"batch_id: 300, loss is: [1.4924574], acc is: [1.]\n", "metadata": {},
"batch_id: 400, loss is: [1.4762049], acc is: [1.]\n", "output_type": "execute_result"
"batch_id: 500, loss is: [1.4817208], acc is: [0.984375]\n",
"batch_id: 600, loss is: [1.4763825], acc is: [0.984375]\n",
"batch_id: 700, loss is: [1.4954926], acc is: [1.]\n",
"batch_id: 800, loss is: [1.5220823], acc is: [0.984375]\n",
"batch_id: 900, loss is: [1.4945463], acc is: [0.984375]\n"
]
} }
], ],
"source": [ "source": [
"import paddle\n", "model.evaluate(test_dataset, batch_size=64)"
"test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"def test(model):\n",
" model.eval()\n",
" batch_size = 64\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
"test(model)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 训练方式2结束\n", "### 训练方式结束\n",
"以上就是训练方式2,通过这种方式,可以清楚的看到训练和测试中的每一步过程。" "以上就是训练方式二,可以快速、高效的完成网络模型训练与预测。"
] ]
}, },
{ {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册