提交 a30bf285 编写于 作者: D dingjiaweiww

change api name

上级 fcce28ed
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
"from paddle.nn import Layer\n", "from paddle.nn import Layer\n",
"from paddle.vision.datasets import MNIST\n", "from paddle.vision.datasets import MNIST\n",
"from paddle.metric import Accuracy\n", "from paddle.metric import Accuracy\n",
"from paddle.nn import Conv2d,Pool2D,Linear\n", "from paddle.nn import Conv2d,MaxPool2d,Linear\n",
"from paddle.static import InputSpec\n", "from paddle.static import InputSpec\n",
"\n", "\n",
"print(paddle.__version__)\n", "print(paddle.__version__)\n",
...@@ -54,7 +54,7 @@ ...@@ -54,7 +54,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -71,7 +71,7 @@ ...@@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -79,9 +79,9 @@ ...@@ -79,9 +79,9 @@
" def __init__(self):\n", " def __init__(self):\n",
" super(MyModel, self).__init__()\n", " super(MyModel, self).__init__()\n",
" self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n", " self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n",
" self.max_pool1 = Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", " self.max_pool1 = MaxPool2d(kernel_size=2, stride=2)\n",
" self.conv2 = Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n", " self.conv2 = Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n",
" self.max_pool2 = Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", " self.max_pool2 = MaxPool2d(kernel_size=2, stride=2)\n",
" self.linear1 = Linear(in_features=16*5*5, out_features=120)\n", " self.linear1 = Linear(in_features=16*5*5, out_features=120)\n",
" self.linear2 = Linear(in_features=120, out_features=84)\n", " self.linear2 = Linear(in_features=120, out_features=84)\n",
" self.linear3 = Linear(in_features=84, out_features=10)\n", " self.linear3 = Linear(in_features=84, out_features=10)\n",
...@@ -93,7 +93,7 @@ ...@@ -93,7 +93,7 @@
" x = F.relu(x)\n", " x = F.relu(x)\n",
" x = self.conv2(x)\n", " x = self.conv2(x)\n",
" x = self.max_pool2(x)\n", " x = self.max_pool2(x)\n",
" x = paddle.reshape(x, shape=[-1, 16*5*5])\n", " x = paddle.flatten(x, start_axis=1, stop_axis=-1)\n",
" x = self.linear1(x)\n", " x = self.linear1(x)\n",
" x = F.relu(x)\n", " x = F.relu(x)\n",
" x = self.linear2(x)\n", " x = self.linear2(x)\n",
...@@ -113,7 +113,7 @@ ...@@ -113,7 +113,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -121,22 +121,22 @@ ...@@ -121,22 +121,22 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 1/1\n", "Epoch 1/1\n",
"step 100/938 - loss: 1.6444 - acc_top1: 0.5708 - acc_top2: 0.6325 - 16ms/step\n", "step 100/938 - loss: 1.6177 - acc_top1: 0.6119 - acc_top2: 0.6813 - 15ms/step\n",
"step 200/938 - loss: 1.7200 - acc_top1: 0.6946 - acc_top2: 0.7496 - 16ms/step\n", "step 200/938 - loss: 1.7720 - acc_top1: 0.7230 - acc_top2: 0.7788 - 15ms/step\n",
"step 300/938 - loss: 1.5864 - acc_top1: 0.7472 - acc_top2: 0.7947 - 16ms/step\n", "step 300/938 - loss: 1.6114 - acc_top1: 0.7666 - acc_top2: 0.8164 - 15ms/step\n",
"step 400/938 - loss: 1.5369 - acc_top1: 0.7743 - acc_top2: 0.8161 - 16ms/step\n", "step 400/938 - loss: 1.6537 - acc_top1: 0.7890 - acc_top2: 0.8350 - 15ms/step\n",
"step 500/938 - loss: 1.6392 - acc_top1: 0.7935 - acc_top2: 0.8309 - 16ms/step\n", "step 500/938 - loss: 1.5229 - acc_top1: 0.8170 - acc_top2: 0.8619 - 15ms/step\n",
"step 600/938 - loss: 1.5316 - acc_top1: 0.8066 - acc_top2: 0.8411 - 16ms/step\n", "step 600/938 - loss: 1.5269 - acc_top1: 0.8391 - acc_top2: 0.8821 - 15ms/step\n",
"step 700/938 - loss: 1.5870 - acc_top1: 0.8155 - acc_top2: 0.8478 - 16ms/step\n", "step 700/938 - loss: 1.4821 - acc_top1: 0.8561 - acc_top2: 0.8970 - 15ms/step\n",
"step 800/938 - loss: 1.6136 - acc_top1: 0.8230 - acc_top2: 0.8532 - 16ms/step\n", "step 800/938 - loss: 1.4860 - acc_top1: 0.8689 - acc_top2: 0.9081 - 15ms/step\n",
"step 900/938 - loss: 1.5605 - acc_top1: 0.8290 - acc_top2: 0.8574 - 16ms/step\n", "step 900/938 - loss: 1.5032 - acc_top1: 0.8799 - acc_top2: 0.9174 - 15ms/step\n",
"step 938/938 - loss: 1.4618 - acc_top1: 0.8312 - acc_top2: 0.8591 - 16ms/step\n", "step 938/938 - loss: 1.4617 - acc_top1: 0.8835 - acc_top2: 0.9203 - 15ms/step\n",
"save checkpoint at /Users/dingjiawei/Desktop/教程/mnist_checkpoint/0\n", "save checkpoint at /Users/dingjiawei/online_repo/book/paddle2.0_docs/save_model/mnist_checkpoint/0\n",
"Eval begin...\n", "Eval begin...\n",
"step 100/157 - loss: 1.5209 - acc_top1: 0.8700 - acc_top2: 0.8912 - 5ms/step\n", "step 100/157 - loss: 1.4765 - acc_top1: 0.9636 - acc_top2: 0.9891 - 6ms/step\n",
"step 157/157 - loss: 1.5226 - acc_top1: 0.8769 - acc_top2: 0.8939 - 5ms/step\n", "step 157/157 - loss: 1.4612 - acc_top1: 0.9705 - acc_top2: 0.9910 - 6ms/step\n",
"Eval samples: 10000\n", "Eval samples: 10000\n",
"save checkpoint at /Users/dingjiawei/Desktop/教程/mnist_checkpoint/final\n" "save checkpoint at /Users/dingjiawei/online_repo/book/paddle2.0_docs/save_model/mnist_checkpoint/final\n"
] ]
} }
], ],
...@@ -272,7 +272,7 @@ ...@@ -272,7 +272,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -280,34 +280,34 @@ ...@@ -280,34 +280,34 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 1/2\n", "Epoch 1/2\n",
"step 100/938 - loss: 1.4777 - acc_top1: 0.9806 - acc_top2: 0.9962 - 16ms/step\n", "step 100/938 - loss: 1.4635 - acc_top1: 0.9650 - acc_top2: 0.9898 - 15ms/step\n",
"step 200/938 - loss: 1.5163 - acc_top1: 0.9795 - acc_top2: 0.9962 - 16ms/step\n", "step 200/938 - loss: 1.5459 - acc_top1: 0.9659 - acc_top2: 0.9897 - 15ms/step\n",
"step 300/938 - loss: 1.4872 - acc_top1: 0.9796 - acc_top2: 0.9957 - 16ms/step\n", "step 300/938 - loss: 1.5109 - acc_top1: 0.9658 - acc_top2: 0.9893 - 15ms/step\n",
"step 400/938 - loss: 1.4717 - acc_top1: 0.9795 - acc_top2: 0.9955 - 16ms/step\n", "step 400/938 - loss: 1.4797 - acc_top1: 0.9664 - acc_top2: 0.9899 - 15ms/step\n",
"step 500/938 - loss: 1.4778 - acc_top1: 0.9794 - acc_top2: 0.9955 - 16ms/step\n", "step 500/938 - loss: 1.4786 - acc_top1: 0.9673 - acc_top2: 0.9902 - 15ms/step\n",
"step 600/938 - loss: 1.4653 - acc_top1: 0.9798 - acc_top2: 0.9955 - 16ms/step\n", "step 600/938 - loss: 1.5082 - acc_top1: 0.9679 - acc_top2: 0.9906 - 15ms/step\n",
"step 700/938 - loss: 1.4768 - acc_top1: 0.9799 - acc_top2: 0.9954 - 16ms/step\n", "step 700/938 - loss: 1.4768 - acc_top1: 0.9687 - acc_top2: 0.9909 - 15ms/step\n",
"step 800/938 - loss: 1.4771 - acc_top1: 0.9804 - acc_top2: 0.9954 - 16ms/step\n", "step 800/938 - loss: 1.4638 - acc_top1: 0.9696 - acc_top2: 0.9913 - 15ms/step\n",
"step 900/938 - loss: 1.4864 - acc_top1: 0.9807 - acc_top2: 0.9954 - 16ms/step\n", "step 900/938 - loss: 1.5058 - acc_top1: 0.9704 - acc_top2: 0.9916 - 15ms/step\n",
"step 938/938 - loss: 1.4612 - acc_top1: 0.9807 - acc_top2: 0.9955 - 16ms/step\n", "step 938/938 - loss: 1.4702 - acc_top1: 0.9708 - acc_top2: 0.9917 - 15ms/step\n",
"Eval begin...\n", "Eval begin...\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9762 - acc_top2: 0.9952 - 6ms/step\n", "step 100/157 - loss: 1.4613 - acc_top1: 0.9755 - acc_top2: 0.9944 - 5ms/step\n",
"step 157/157 - loss: 1.4612 - acc_top1: 0.9807 - acc_top2: 0.9959 - 6ms/step\n", "step 157/157 - loss: 1.4612 - acc_top1: 0.9805 - acc_top2: 0.9956 - 5ms/step\n",
"Eval samples: 10000\n", "Eval samples: 10000\n",
"Epoch 2/2\n", "Epoch 2/2\n",
"step 100/938 - loss: 1.4696 - acc_top1: 0.9812 - acc_top2: 0.9942 - 16ms/step\n", "step 100/938 - loss: 1.4832 - acc_top1: 0.9789 - acc_top2: 0.9927 - 15ms/step\n",
"step 200/938 - loss: 1.4619 - acc_top1: 0.9827 - acc_top2: 0.9956 - 16ms/step\n", "step 200/938 - loss: 1.4618 - acc_top1: 0.9779 - acc_top2: 0.9932 - 14ms/step\n",
"step 300/938 - loss: 1.4616 - acc_top1: 0.9826 - acc_top2: 0.9955 - 16ms/step\n", "step 300/938 - loss: 1.4613 - acc_top1: 0.9779 - acc_top2: 0.9929 - 15ms/step\n",
"step 400/938 - loss: 1.4766 - acc_top1: 0.9824 - acc_top2: 0.9954 - 16ms/step\n", "step 400/938 - loss: 1.4765 - acc_top1: 0.9772 - acc_top2: 0.9932 - 15ms/step\n",
"step 500/938 - loss: 1.4770 - acc_top1: 0.9830 - acc_top2: 0.9953 - 16ms/step\n", "step 500/938 - loss: 1.4932 - acc_top1: 0.9775 - acc_top2: 0.9934 - 15ms/step\n",
"step 600/938 - loss: 1.4924 - acc_top1: 0.9831 - acc_top2: 0.9955 - 16ms/step\n", "step 600/938 - loss: 1.4773 - acc_top1: 0.9773 - acc_top2: 0.9936 - 15ms/step\n",
"step 700/938 - loss: 1.4623 - acc_top1: 0.9837 - acc_top2: 0.9959 - 16ms/step\n", "step 700/938 - loss: 1.4612 - acc_top1: 0.9783 - acc_top2: 0.9939 - 15ms/step\n",
"step 800/938 - loss: 1.4768 - acc_top1: 0.9839 - acc_top2: 0.9960 - 16ms/step\n", "step 800/938 - loss: 1.4653 - acc_top1: 0.9779 - acc_top2: 0.9939 - 15ms/step\n",
"step 900/938 - loss: 1.4768 - acc_top1: 0.9838 - acc_top2: 0.9960 - 16ms/step\n", "step 900/938 - loss: 1.4639 - acc_top1: 0.9780 - acc_top2: 0.9939 - 15ms/step\n",
"step 938/938 - loss: 1.4879 - acc_top1: 0.9838 - acc_top2: 0.9960 - 16ms/step\n", "step 938/938 - loss: 1.4678 - acc_top1: 0.9779 - acc_top2: 0.9937 - 15ms/step\n",
"Eval begin...\n", "Eval begin...\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9825 - acc_top2: 0.9956 - 6ms/step\n", "step 100/157 - loss: 1.4612 - acc_top1: 0.9733 - acc_top2: 0.9945 - 6ms/step\n",
"step 157/157 - loss: 1.4701 - acc_top1: 0.9854 - acc_top2: 0.9965 - 6ms/step\n", "step 157/157 - loss: 1.4612 - acc_top1: 0.9778 - acc_top2: 0.9952 - 6ms/step\n",
"Eval samples: 10000\n" "Eval samples: 10000\n"
] ]
} }
...@@ -323,13 +323,12 @@ ...@@ -323,13 +323,12 @@
"test_dataset = MNIST(mode='test')\n", "test_dataset = MNIST(mode='test')\n",
"\n", "\n",
"paddle.disable_static()\n", "paddle.disable_static()\n",
"params_path = \"mnist_checkpoint/test\"\n",
"\n", "\n",
"inputs = InputSpec([None, 784], 'float32', 'x')\n", "inputs = InputSpec([None, 784], 'float32', 'x')\n",
"labels = InputSpec([None, 10], 'float32', 'x')\n", "labels = InputSpec([None, 10], 'float32', 'x')\n",
"model = paddle.Model(MyModel(), inputs, labels)\n", "model = paddle.Model(MyModel(), inputs, labels)\n",
"optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", "optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
"model.load(\"../教程/mnist_checkpoint/final\")\n", "model.load(\"./mnist_checkpoint/final\")\n",
"model.prepare( \n", "model.prepare( \n",
" optim,\n", " optim,\n",
" paddle.nn.loss.CrossEntropyLoss(),\n", " paddle.nn.loss.CrossEntropyLoss(),\n",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册