diff --git a/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4e544b826c5ad855bbe40ab28e91f40b4e85907c --- /dev/null +++ b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb @@ -0,0 +1,666 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MNIST数据集使用LeNet进行图像分类\n", + "本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。\n", + "手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 环境\n", + "本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.0.0-alpha0\n" + ] + } + ], + "source": [ + "import paddle\n", + "print(paddle.__version__)\n", + "paddle.enable_imperative()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 加载数据集\n", + "我们使用飞桨自带的paddle.dataset完成mnist数据集的加载。" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "download training data and load training data\n", + "load finished\n", + "\n" + ] + } + ], + "source": [ + "print('download training data and load training data')\n", + "train_dataset = paddle.incubate.hapi.datasets.MNIST(mode='train')\n", + "test_dataset = paddle.incubate.hapi.datasets.MNIST(mode='test')\n", + "print('load finished')\n", + "print(type(train_dataset[0][0]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "取训练集中的一条数据看一下。" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train_data0 label is: [5]\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIY0lEQVR4nO3dXWhUZxoH8P/jaPxav7KREtNgiooQFvwg1l1cNOr6sQUN3ixR0VUK9cKPXTBYs17ohReLwl5ovCmuZMU1y+IaWpdC0GIuxCJJMLhJa6oWtSl+FVEXvdDK24s5nc5zapKTZ86cOTPz/4Hk/M8xc17w8Z13zpl5RpxzIBquEbkeAOUnFg6ZsHDIhIVDJiwcMmHhkElGhSMiq0WkT0RuisjesAZF8SfW6zgikgDwFYAVAPoBdABY75z7IrzhUVyNzOB33wVw0zn3NQCIyL8A1AEYsHDKyspcVVVVBqekqHV1dX3nnJvq359J4VQA+CYt9wNYONgvVFVVobOzM4NTUtRE5M6b9md9cSwiH4hIp4h0Pnr0KNuno4hkUjjfAqhMy297+xTn3EfOuRrnXM3UqT+b8ShPZVI4HQBmicg7IlICoB7AJ+EMi+LOvMZxzn0vIjsAtAFIADjhnOsNbWQUa5ksjuGc+xTApyGNhfIIrxyTCQuHTFg4ZMLCIRMWDpmwcMiEhUMmLBwyYeGQCQuHTFg4ZMLCIZOMbnIWk9evX6v89OnTwL/b1NSk8osXL1Tu6+tT+dixYyo3NDSo3NLSovKYMWNU3rv3p88N7N+/P/A4h4MzDpmwcMiEhUMmRbPGuXv3rsovX75U+fLlyypfunRJ5SdPnqh85syZ0MZWWVmp8s6dO1VubW1VecKECSrPmTNH5SVLloQ2toFwxiETFg6ZsHDIpGDXOFevXlV52bJlKg/nOkzYEomEygcPHlR5/PjxKm/cuFHladOmqTxlyhSVZ8+enekQh8QZh0xYOGTCwiGTgl3jTJ8+XeWysjKVw1zjLFyom3T41xwXL15UuaSkROVNmzaFNpaocMYhExYOmbBwyKRg1zilpaUqHz58WOVz586pPG/ePJV37do16OPPnTs3tX3hwgV1zH8dpqenR+UjR44M+tj5gDMOmQxZOCJyQkQeikhP2r5SETkvIje8n1MGewwqPEFmnGYAq3379gL4zDk3C8BnXqYiEqjPsYhUAfivc+5XXu4DUOucuyci5QDanXND3iCpqalxcek6+uzZM5X973HZtm2bysePH1f51KlTqe0NGzaEPLr4EJEu51yNf791jfOWc+6et30fwFvmkVFeynhx7JJT1oDTFtvVFiZr4TzwnqLg/Xw40F9ku9rCZL2O8wmAPwL4q/fz49BGFJGJEycOenzSpEmDHk9f89TX16tjI0YU/lWOIC/HWwB8DmC2iPSLyPtIFswKEbkB4HdepiIy5IzjnFs/wKHlIY+F8kjhz6mUFQV7rypTBw4cULmrq0vl9vb21Lb/XtXKlSuzNazY4IxDJiwcMmHhkIn5Ozkt4nSvarhu3bql8vz581PbkydPVseWLl2qck2NvtWzfft2lUUkhBFmR9j3qqjIsXDIhC/HA5oxY4bKzc3Nqe2tW7eqYydPnhw0P3/+XOXNmzerXF5ebh1mZDjjkAkLh0xYOGTCNY7RunXrUtszZ85Ux3bv3q2y/5ZEY2Ojynfu6O+E37dvn8oVFRXmcWYLZxwyYeGQCQuHTHjLIQv8rW39HzfesmWLyv5/g+XL9Xvkzp8/H9rYhou3HChULBwyYeGQCdc4OTB69GiVX716pfKoUaNUbmtrU7m2tjYr43oTrnEoVCwcMmHhkAnvVYXg2rVrKvu/kqijo0Nl/5rGr7q6WuXFixdnMLrs4IxDJiwcMmHhkAnXOAH5v+L56NGjqe2zZ8+qY/fv3x/WY48cqf8Z/O85jmPblPiNiPJCkP44lSJyUUS+EJFeEfmTt58ta4tYkBnnewC7nXPVAH4NYLuIVIMta4takMZK9wDc87b/LyJfAqgAUAeg1vtr/wDQDuDDrIwyAv51yenTp1VuampS+fbt2+ZzLViwQGX/e4zXrl1rfuyoDGuN4/U7ngfgCtiytqgFLhwR+QWA/wD4s3NOdZcerGUt29UWpkCFIyKjkCyafzrnfnztGahlLdvVFqYh1ziS7MHxdwBfOuf+lnYor1rWPnjwQOXe3l6Vd+zYofL169fN5/J/1eKePXtUrqurUzmO12mGEuQC4CIAmwD8T0S6vX1/QbJg/u21r70D4A9ZGSHFUpBXVZcADNT5hy1ri1T+zZEUCwVzr+rx48cq+782qLu7W2V/a7bhWrRoUWrb/1nxVatWqTx27NiMzhVHnHHIhIVDJiwcMsmrNc6VK1dS24cOHVLH/O/r7e/vz+hc48aNU9n/ddLp95f8XxddDDjjkAkLh0zy6qmqtbX1jdtB+D9ysmbNGpUTiYTKDQ0NKvu7pxc7zjhkwsIhExYOmbDNCQ2KbU4oVCwcMmHhkAkLh0xYOGTCwiETFg6ZsHDIhIVDJiwcMmHhkEmk96pE5BGSn/osA/BdZCcenriOLVfjmu6c+9mH/iMtnNRJRTrfdOMsDuI6triNi09VZMLCIZNcFc5HOTpvEHEdW6zGlZM1DuU/PlWRSaSFIyKrRaRPRG6KSE7b24rICRF5KCI9afti0bs5H3pLR1Y4IpIAcAzA7wFUA1jv9UvOlWYAq3374tK7Of69pZ1zkfwB8BsAbWm5EUBjVOcfYExVAHrSch+Acm+7HEBfLseXNq6PAayI0/iifKqqAPBNWu739sVJ7Ho3x7W3NBfHA3DJ/9Y5fclp7S0dhSgL51sAlWn5bW9fnATq3RyFTHpLRyHKwukAMEtE3hGREgD1SPZKjpMfezcDOezdHKC3NJDr3tIRL/LeA/AVgFsA9uV4wdmC5JebvEJyvfU+gF8i+WrlBoALAEpzNLbfIvk0dA1At/fnvbiMzznHK8dkw8UxmbBwyISFQyYsHDJh4ZAJC4dMWDhkwsIhkx8AyyZIbAmqetUAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]\n", + "train_data0 = train_data0.transpose(1,2,0)\n", + "plt.figure(figsize=(2,2))\n", + "plt.imshow(train_data0, cmap=plt.cm.binary)\n", + "print('train_data0 label is: ' + str(train_label_0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2.组网&训练方案1\n", + "paddle支持用model类,直接完成模型的训练,具体如下。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 首先需要继承Model来自定义LeNet网络。" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "class LeNet(paddle.incubate.hapi.model.Model):\n", + " def __init__(self):\n", + " super(LeNet, self).__init__()\n", + " self.conv1 = paddle.nn.Conv2D(num_channels=1, num_filters=6, filter_size=5, stride=1, padding=2, act='relu')\n", + " self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.conv2 = paddle.nn.Conv2D(num_channels=6, num_filters=16, filter_size=5, stride=1, act='relu')\n", + " self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.linear1 = paddle.nn.Linear(input_dim=16*5*5, output_dim=120, act='relu')\n", + " self.linear2 = paddle.nn.Linear(input_dim=120, output_dim=84, act='relu')\n", + " self.linear3 = paddle.nn.Linear(input_dim=84, output_dim=10, act='softmax')\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.max_pool1(x)\n", + " x = self.conv2(x)\n", + " x = self.max_pool2(x)\n", + " x = paddle.reshape(x, shape=[-1, 16*5*5])\n", + " x = self.linear1(x)\n", + " x = self.linear2(x)\n", + " x = self.linear3(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 初始化Model,并定义相关的参数。" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from paddle.incubate.hapi.model import Input\n", + "from paddle.incubate.hapi.loss import CrossEntropy\n", + "from paddle.incubate.hapi.metrics import Accuracy\n", + "\n", + "inputs = [Input([None, 1, 28, 28], 'float32', name='image')]\n", + "labels = [Input([None, 1], 'int64', name='label')]\n", + "model = LeNet()\n", + "optim = paddle.optimizer.Adam(learning_rate=0.001, parameter_list=model.parameters())\n", + "\n", + "model.prepare(\n", + " optim,\n", + " CrossEntropy(),\n", + " Accuracy(topk=(1, 2)),\n", + " inputs=inputs,\n", + " labels=labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 使用fit来训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/2\n", + "step 10/938 - loss: 2.1912 - acc_top1: 0.2719 - acc_top2: 0.4109 - 16ms/step\n", + "step 20/938 - loss: 1.6389 - acc_top1: 0.4109 - acc_top2: 0.5367 - 15ms/step\n", + "step 30/938 - loss: 1.1486 - acc_top1: 0.4797 - acc_top2: 0.6135 - 15ms/step\n", + "step 40/938 - loss: 0.7755 - acc_top1: 0.5484 - acc_top2: 0.6770 - 15ms/step\n", + "step 50/938 - loss: 0.7651 - acc_top1: 0.5975 - acc_top2: 0.7266 - 15ms/step\n", + "step 60/938 - loss: 0.3837 - acc_top1: 0.6393 - acc_top2: 0.7617 - 15ms/step\n", + "step 70/938 - loss: 0.6532 - acc_top1: 0.6712 - acc_top2: 0.7888 - 15ms/step\n", + "step 80/938 - loss: 0.3394 - acc_top1: 0.6969 - 
acc_top2: 0.8107 - 15ms/step\n", + "step 90/938 - loss: 0.2527 - acc_top1: 0.7189 - acc_top2: 0.8283 - 15ms/step\n", + "step 100/938 - loss: 0.2055 - acc_top1: 0.7389 - acc_top2: 0.8427 - 14ms/step\n", + "step 110/938 - loss: 0.3987 - acc_top1: 0.7531 - acc_top2: 0.8536 - 14ms/step\n", + "step 120/938 - loss: 0.2372 - acc_top1: 0.7660 - acc_top2: 0.8622 - 14ms/step\n", + "step 130/938 - loss: 0.4071 - acc_top1: 0.7780 - acc_top2: 0.8708 - 14ms/step\n", + "step 140/938 - loss: 0.1315 - acc_top1: 0.7895 - acc_top2: 0.8780 - 14ms/step\n", + "step 150/938 - loss: 0.3168 - acc_top1: 0.7981 - acc_top2: 0.8843 - 15ms/step\n", + "step 160/938 - loss: 0.2782 - acc_top1: 0.8063 - acc_top2: 0.8901 - 15ms/step\n", + "step 170/938 - loss: 0.2030 - acc_top1: 0.8144 - acc_top2: 0.8956 - 15ms/step\n", + "step 180/938 - loss: 0.2336 - acc_top1: 0.8203 - acc_top2: 0.9000 - 15ms/step\n", + "step 190/938 - loss: 0.5915 - acc_top1: 0.8260 - acc_top2: 0.9038 - 15ms/step\n", + "step 200/938 - loss: 0.4995 - acc_top1: 0.8310 - acc_top2: 0.9076 - 15ms/step\n", + "step 210/938 - loss: 0.2190 - acc_top1: 0.8359 - acc_top2: 0.9106 - 15ms/step\n", + "step 220/938 - loss: 0.1835 - acc_top1: 0.8397 - acc_top2: 0.9130 - 15ms/step\n", + "step 230/938 - loss: 0.1321 - acc_top1: 0.8442 - acc_top2: 0.9159 - 15ms/step\n", + "step 240/938 - loss: 0.2406 - acc_top1: 0.8478 - acc_top2: 0.9183 - 15ms/step\n", + "step 250/938 - loss: 0.1245 - acc_top1: 0.8518 - acc_top2: 0.9209 - 15ms/step\n", + "step 260/938 - loss: 0.1570 - acc_top1: 0.8559 - acc_top2: 0.9236 - 15ms/step\n", + "step 270/938 - loss: 0.1647 - acc_top1: 0.8593 - acc_top2: 0.9259 - 15ms/step\n", + "step 280/938 - loss: 0.1876 - acc_top1: 0.8625 - acc_top2: 0.9281 - 14ms/step\n", + "step 290/938 - loss: 0.2247 - acc_top1: 0.8650 - acc_top2: 0.9300 - 15ms/step\n", + "step 300/938 - loss: 0.2070 - acc_top1: 0.8679 - acc_top2: 0.9318 - 15ms/step\n", + "step 310/938 - loss: 0.1122 - acc_top1: 0.8701 - acc_top2: 0.9333 - 14ms/step\n", + "step 320/938 - loss: 0.0857 - acc_top1: 0.8729 - acc_top2: 0.9351 - 14ms/step\n", + "step 330/938 - loss: 0.2414 - acc_top1: 0.8751 - acc_top2: 0.9365 - 14ms/step\n", + "step 340/938 - loss: 0.2631 - acc_top1: 0.8774 - acc_top2: 0.9380 - 14ms/step\n", + "step 350/938 - loss: 0.1347 - acc_top1: 0.8796 - acc_top2: 0.9396 - 14ms/step\n", + "step 360/938 - loss: 0.2295 - acc_top1: 0.8816 - acc_top2: 0.9409 - 14ms/step\n", + "step 370/938 - loss: 0.2971 - acc_top1: 0.8842 - acc_top2: 0.9423 - 14ms/step\n", + "step 380/938 - loss: 0.1623 - acc_top1: 0.8863 - acc_top2: 0.9436 - 14ms/step\n", + "step 390/938 - loss: 0.1020 - acc_top1: 0.8880 - acc_top2: 0.9448 - 14ms/step\n", + "step 400/938 - loss: 0.0716 - acc_top1: 0.8895 - acc_top2: 0.9459 - 14ms/step\n", + "step 410/938 - loss: 0.0889 - acc_top1: 0.8914 - acc_top2: 0.9469 - 14ms/step\n", + "step 420/938 - loss: 0.1010 - acc_top1: 0.8931 - acc_top2: 0.9478 - 14ms/step\n", + "step 430/938 - loss: 0.0486 - acc_top1: 0.8945 - acc_top2: 0.9487 - 14ms/step\n", + "step 440/938 - loss: 0.1723 - acc_top1: 0.8958 - acc_top2: 0.9495 - 14ms/step\n", + "step 450/938 - loss: 0.2270 - acc_top1: 0.8974 - acc_top2: 0.9503 - 14ms/step\n", + "step 460/938 - loss: 0.1197 - acc_top1: 0.8987 - acc_top2: 0.9512 - 14ms/step\n", + "step 470/938 - loss: 0.2837 - acc_top1: 0.9002 - acc_top2: 0.9519 - 14ms/step\n", + "step 480/938 - loss: 0.1091 - acc_top1: 0.9017 - acc_top2: 0.9528 - 14ms/step\n", + "step 490/938 - loss: 0.1397 - acc_top1: 0.9029 - acc_top2: 0.9535 - 14ms/step\n", + "step 500/938 - loss: 0.1034 - 
acc_top1: 0.9040 - acc_top2: 0.9543 - 14ms/step\n", + "step 510/938 - loss: 0.0095 - acc_top1: 0.9054 - acc_top2: 0.9550 - 14ms/step\n", + "step 520/938 - loss: 0.0092 - acc_top1: 0.9068 - acc_top2: 0.9558 - 14ms/step\n", + "step 530/938 - loss: 0.0633 - acc_top1: 0.9077 - acc_top2: 0.9565 - 14ms/step\n", + "step 540/938 - loss: 0.0936 - acc_top1: 0.9086 - acc_top2: 0.9571 - 14ms/step\n", + "step 550/938 - loss: 0.1180 - acc_top1: 0.9097 - acc_top2: 0.9577 - 14ms/step\n", + "step 560/938 - loss: 0.1600 - acc_top1: 0.9106 - acc_top2: 0.9583 - 14ms/step\n", + "step 570/938 - loss: 0.1338 - acc_top1: 0.9118 - acc_top2: 0.9590 - 14ms/step\n", + "step 580/938 - loss: 0.0496 - acc_top1: 0.9128 - acc_top2: 0.9595 - 14ms/step\n", + "step 590/938 - loss: 0.0651 - acc_top1: 0.9138 - acc_top2: 0.9600 - 14ms/step\n", + "step 600/938 - loss: 0.1306 - acc_top1: 0.9147 - acc_top2: 0.9605 - 14ms/step\n", + "step 610/938 - loss: 0.0744 - acc_top1: 0.9157 - acc_top2: 0.9610 - 14ms/step\n", + "step 620/938 - loss: 0.1679 - acc_top1: 0.9166 - acc_top2: 0.9616 - 14ms/step\n", + "step 630/938 - loss: 0.0789 - acc_top1: 0.9173 - acc_top2: 0.9621 - 14ms/step\n", + "step 640/938 - loss: 0.0767 - acc_top1: 0.9182 - acc_top2: 0.9626 - 14ms/step\n", + "step 650/938 - loss: 0.1776 - acc_top1: 0.9188 - acc_top2: 0.9630 - 14ms/step\n", + "step 660/938 - loss: 0.1371 - acc_top1: 0.9196 - acc_top2: 0.9634 - 14ms/step\n", + "step 670/938 - loss: 0.1011 - acc_top1: 0.9204 - acc_top2: 0.9639 - 14ms/step\n", + "step 680/938 - loss: 0.0447 - acc_top1: 0.9209 - acc_top2: 0.9642 - 14ms/step\n", + "step 690/938 - loss: 0.0230 - acc_top1: 0.9217 - acc_top2: 0.9646 - 14ms/step\n", + "step 700/938 - loss: 0.0541 - acc_top1: 0.9224 - acc_top2: 0.9649 - 14ms/step\n", + "step 710/938 - loss: 0.1395 - acc_top1: 0.9231 - acc_top2: 0.9653 - 14ms/step\n", + "step 720/938 - loss: 0.0426 - acc_top1: 0.9238 - acc_top2: 0.9657 - 14ms/step\n", + "step 730/938 - loss: 0.0540 - acc_top1: 0.9247 - acc_top2: 0.9660 - 14ms/step\n", + "step 740/938 - loss: 0.1132 - acc_top1: 0.9253 - acc_top2: 0.9664 - 14ms/step\n", + "step 750/938 - loss: 0.0088 - acc_top1: 0.9261 - acc_top2: 0.9668 - 14ms/step\n", + "step 760/938 - loss: 0.0282 - acc_top1: 0.9266 - acc_top2: 0.9672 - 14ms/step\n", + "step 770/938 - loss: 0.1233 - acc_top1: 0.9272 - acc_top2: 0.9675 - 14ms/step\n", + "step 780/938 - loss: 0.2208 - acc_top1: 0.9275 - acc_top2: 0.9677 - 14ms/step\n", + "step 790/938 - loss: 0.0599 - acc_top1: 0.9281 - acc_top2: 0.9680 - 14ms/step\n", + "step 800/938 - loss: 0.0270 - acc_top1: 0.9287 - acc_top2: 0.9683 - 14ms/step\n", + "step 810/938 - loss: 0.1546 - acc_top1: 0.9291 - acc_top2: 0.9687 - 14ms/step\n", + "step 820/938 - loss: 0.0252 - acc_top1: 0.9297 - acc_top2: 0.9689 - 14ms/step\n", + "step 830/938 - loss: 0.0276 - acc_top1: 0.9304 - acc_top2: 0.9693 - 14ms/step\n", + "step 840/938 - loss: 0.0620 - acc_top1: 0.9309 - acc_top2: 0.9695 - 14ms/step\n", + "step 850/938 - loss: 0.0505 - acc_top1: 0.9314 - acc_top2: 0.9699 - 14ms/step\n", + "step 860/938 - loss: 0.0156 - acc_top1: 0.9319 - acc_top2: 0.9701 - 14ms/step\n", + "step 870/938 - loss: 0.0229 - acc_top1: 0.9325 - acc_top2: 0.9704 - 14ms/step\n", + "step 880/938 - loss: 0.0498 - acc_top1: 0.9330 - acc_top2: 0.9707 - 14ms/step\n", + "step 890/938 - loss: 0.0183 - acc_top1: 0.9335 - acc_top2: 0.9710 - 14ms/step\n", + "step 900/938 - loss: 0.1282 - acc_top1: 0.9339 - acc_top2: 0.9712 - 14ms/step\n", + "step 910/938 - loss: 0.0426 - acc_top1: 0.9342 - acc_top2: 0.9715 - 14ms/step\n", + "step 920/938 - 
loss: 0.0641 - acc_top1: 0.9347 - acc_top2: 0.9717 - 14ms/step\n", + "step 930/938 - loss: 0.0745 - acc_top1: 0.9351 - acc_top2: 0.9719 - 14ms/step\n", + "step 938/938 - loss: 0.0118 - acc_top1: 0.9354 - acc_top2: 0.9721 - 14ms/step\n", + "save checkpoint at mnist_checkpoint/0\n", + "Eval begin...\n", + "step 10/157 - loss: 0.1032 - acc_top1: 0.9828 - acc_top2: 0.9969 - 5ms/step\n", + "step 20/157 - loss: 0.2664 - acc_top1: 0.9781 - acc_top2: 0.9953 - 5ms/step\n", + "step 30/157 - loss: 0.1626 - acc_top1: 0.9766 - acc_top2: 0.9943 - 5ms/step\n", + "step 40/157 - loss: 0.0247 - acc_top1: 0.9734 - acc_top2: 0.9926 - 5ms/step\n", + "step 50/157 - loss: 0.0225 - acc_top1: 0.9738 - acc_top2: 0.9925 - 5ms/step\n", + "step 60/157 - loss: 0.2119 - acc_top1: 0.9737 - acc_top2: 0.9927 - 5ms/step\n", + "step 70/157 - loss: 0.0559 - acc_top1: 0.9723 - acc_top2: 0.9920 - 5ms/step\n", + "step 80/157 - loss: 0.0329 - acc_top1: 0.9725 - acc_top2: 0.9918 - 5ms/step\n", + "step 90/157 - loss: 0.1064 - acc_top1: 0.9741 - acc_top2: 0.9925 - 5ms/step\n", + "step 100/157 - loss: 0.0027 - acc_top1: 0.9744 - acc_top2: 0.9923 - 5ms/step\n", + "step 110/157 - loss: 0.0044 - acc_top1: 0.9750 - acc_top2: 0.9925 - 5ms/step\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 120/157 - loss: 0.0093 - acc_top1: 0.9768 - acc_top2: 0.9931 - 5ms/step\n", + "step 130/157 - loss: 0.1247 - acc_top1: 0.9774 - acc_top2: 0.9935 - 5ms/step\n", + "step 140/157 - loss: 0.0031 - acc_top1: 0.9785 - acc_top2: 0.9940 - 5ms/step\n", + "step 150/157 - loss: 0.0495 - acc_top1: 0.9794 - acc_top2: 0.9944 - 5ms/step\n", + "step 157/157 - loss: 0.0020 - acc_top1: 0.9790 - acc_top2: 0.9944 - 5ms/step\n", + "Eval samples: 10000\n", + "Epoch 2/2\n", + "step 10/938 - loss: 0.1735 - acc_top1: 0.9766 - acc_top2: 0.9938 - 16ms/step\n", + "step 20/938 - loss: 0.0723 - acc_top1: 0.9750 - acc_top2: 0.9922 - 15ms/step\n", + "step 30/938 - loss: 0.0593 - acc_top1: 0.9781 - acc_top2: 0.9927 - 15ms/step\n", + "step 40/938 - loss: 0.1243 - acc_top1: 0.9793 - acc_top2: 0.9938 - 15ms/step\n", + "step 50/938 - loss: 0.0127 - acc_top1: 0.9797 - acc_top2: 0.9944 - 15ms/step\n", + "step 60/938 - loss: 0.0319 - acc_top1: 0.9779 - acc_top2: 0.9938 - 15ms/step\n", + "step 70/938 - loss: 0.0404 - acc_top1: 0.9783 - acc_top2: 0.9946 - 15ms/step\n", + "step 80/938 - loss: 0.1120 - acc_top1: 0.9781 - acc_top2: 0.9943 - 15ms/step\n", + "step 90/938 - loss: 0.0222 - acc_top1: 0.9780 - acc_top2: 0.9944 - 15ms/step\n", + "step 100/938 - loss: 0.0726 - acc_top1: 0.9788 - acc_top2: 0.9948 - 15ms/step\n", + "step 110/938 - loss: 0.0255 - acc_top1: 0.9790 - acc_top2: 0.9952 - 15ms/step\n", + "step 120/938 - loss: 0.2556 - acc_top1: 0.9790 - acc_top2: 0.9948 - 15ms/step\n", + "step 130/938 - loss: 0.0795 - acc_top1: 0.9786 - acc_top2: 0.9945 - 15ms/step\n", + "step 140/938 - loss: 0.1106 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n", + "step 150/938 - loss: 0.0564 - acc_top1: 0.9784 - acc_top2: 0.9946 - 15ms/step\n", + "step 160/938 - loss: 0.1016 - acc_top1: 0.9784 - acc_top2: 0.9947 - 15ms/step\n", + "step 170/938 - loss: 0.0665 - acc_top1: 0.9785 - acc_top2: 0.9946 - 15ms/step\n", + "step 180/938 - loss: 0.0443 - acc_top1: 0.9788 - acc_top2: 0.9946 - 15ms/step\n", + "step 190/938 - loss: 0.0696 - acc_top1: 0.9789 - acc_top2: 0.9947 - 15ms/step\n", + "step 200/938 - loss: 0.0552 - acc_top1: 0.9791 - acc_top2: 0.9948 - 15ms/step\n", + "step 210/938 - loss: 0.1540 - acc_top1: 0.9789 - acc_top2: 0.9946 - 15ms/step\n", + "step 220/938 - loss: 
0.0422 - acc_top1: 0.9791 - acc_top2: 0.9947 - 15ms/step\n", + "step 230/938 - loss: 0.2994 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n", + "step 240/938 - loss: 0.0246 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n", + "step 250/938 - loss: 0.0802 - acc_top1: 0.9788 - acc_top2: 0.9946 - 15ms/step\n", + "step 260/938 - loss: 0.1142 - acc_top1: 0.9787 - acc_top2: 0.9947 - 15ms/step\n", + "step 270/938 - loss: 0.0195 - acc_top1: 0.9785 - acc_top2: 0.9946 - 15ms/step\n", + "step 280/938 - loss: 0.0559 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n", + "step 290/938 - loss: 0.1101 - acc_top1: 0.9786 - acc_top2: 0.9943 - 15ms/step\n", + "step 300/938 - loss: 0.0078 - acc_top1: 0.9786 - acc_top2: 0.9943 - 15ms/step\n", + "step 310/938 - loss: 0.0877 - acc_top1: 0.9789 - acc_top2: 0.9944 - 15ms/step\n", + "step 320/938 - loss: 0.0919 - acc_top1: 0.9790 - acc_top2: 0.9945 - 15ms/step\n", + "step 330/938 - loss: 0.0395 - acc_top1: 0.9789 - acc_top2: 0.9945 - 15ms/step\n", + "step 340/938 - loss: 0.1892 - acc_top1: 0.9787 - acc_top2: 0.9945 - 15ms/step\n", + "step 350/938 - loss: 0.0457 - acc_top1: 0.9784 - acc_top2: 0.9944 - 15ms/step\n", + "step 360/938 - loss: 0.1036 - acc_top1: 0.9786 - acc_top2: 0.9944 - 15ms/step\n", + "step 370/938 - loss: 0.0614 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n", + "step 380/938 - loss: 0.2316 - acc_top1: 0.9787 - acc_top2: 0.9944 - 15ms/step\n", + "step 390/938 - loss: 0.0126 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n", + "step 400/938 - loss: 0.0614 - acc_top1: 0.9789 - acc_top2: 0.9946 - 15ms/step\n", + "step 410/938 - loss: 0.0374 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n", + "step 420/938 - loss: 0.0924 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n", + "step 430/938 - loss: 0.0151 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n", + "step 440/938 - loss: 0.0223 - acc_top1: 0.9791 - acc_top2: 0.9947 - 15ms/step\n", + "step 450/938 - loss: 0.0111 - acc_top1: 0.9793 - acc_top2: 0.9947 - 15ms/step\n", + "step 460/938 - loss: 0.0112 - acc_top1: 0.9793 - acc_top2: 0.9947 - 15ms/step\n", + "step 470/938 - loss: 0.0239 - acc_top1: 0.9794 - acc_top2: 0.9947 - 15ms/step\n", + "step 480/938 - loss: 0.0821 - acc_top1: 0.9795 - acc_top2: 0.9948 - 15ms/step\n", + "step 490/938 - loss: 0.0493 - acc_top1: 0.9796 - acc_top2: 0.9948 - 15ms/step\n", + "step 500/938 - loss: 0.0627 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n", + "step 510/938 - loss: 0.0331 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n", + "step 520/938 - loss: 0.0831 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n", + "step 530/938 - loss: 0.0687 - acc_top1: 0.9796 - acc_top2: 0.9949 - 15ms/step\n", + "step 540/938 - loss: 0.1556 - acc_top1: 0.9794 - acc_top2: 0.9949 - 15ms/step\n", + "step 550/938 - loss: 0.2394 - acc_top1: 0.9795 - acc_top2: 0.9950 - 15ms/step\n", + "step 560/938 - loss: 0.0353 - acc_top1: 0.9794 - acc_top2: 0.9950 - 15ms/step\n", + "step 570/938 - loss: 0.0179 - acc_top1: 0.9794 - acc_top2: 0.9951 - 15ms/step\n", + "step 580/938 - loss: 0.0307 - acc_top1: 0.9796 - acc_top2: 0.9951 - 15ms/step\n", + "step 590/938 - loss: 0.0806 - acc_top1: 0.9796 - acc_top2: 0.9952 - 15ms/step\n", + "step 600/938 - loss: 0.0320 - acc_top1: 0.9796 - acc_top2: 0.9953 - 15ms/step\n", + "step 610/938 - loss: 0.0201 - acc_top1: 0.9798 - acc_top2: 0.9953 - 15ms/step\n", + "step 620/938 - loss: 0.1524 - acc_top1: 0.9797 - acc_top2: 0.9953 - 15ms/step\n", + "step 630/938 - loss: 0.0062 - acc_top1: 0.9797 - acc_top2: 0.9953 - 15ms/step\n", + "step 
640/938 - loss: 0.0908 - acc_top1: 0.9798 - acc_top2: 0.9953 - 15ms/step\n", + "step 650/938 - loss: 0.0467 - acc_top1: 0.9799 - acc_top2: 0.9954 - 15ms/step\n", + "step 660/938 - loss: 0.0156 - acc_top1: 0.9801 - acc_top2: 0.9954 - 15ms/step\n", + "step 670/938 - loss: 0.0318 - acc_top1: 0.9802 - acc_top2: 0.9955 - 15ms/step\n", + "step 680/938 - loss: 0.0133 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n", + "step 690/938 - loss: 0.0651 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n", + "step 700/938 - loss: 0.0052 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 710/938 - loss: 0.1208 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 720/938 - loss: 0.1519 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n", + "step 730/938 - loss: 0.0954 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n", + "step 740/938 - loss: 0.0059 - acc_top1: 0.9806 - acc_top2: 0.9955 - 15ms/step\n", + "step 750/938 - loss: 0.1000 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n", + "step 760/938 - loss: 0.0629 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n", + "step 770/938 - loss: 0.0182 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n", + "step 780/938 - loss: 0.0215 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n", + "step 790/938 - loss: 0.0418 - acc_top1: 0.9804 - acc_top2: 0.9956 - 15ms/step\n", + "step 800/938 - loss: 0.0132 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n", + "step 810/938 - loss: 0.0546 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 820/938 - loss: 0.0373 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 830/938 - loss: 0.0965 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 840/938 - loss: 0.0143 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 850/938 - loss: 0.0578 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 860/938 - loss: 0.0205 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 870/938 - loss: 0.0384 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n", + "step 880/938 - loss: 0.0157 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 890/938 - loss: 0.0457 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 900/938 - loss: 0.0202 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n", + "step 910/938 - loss: 0.0240 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 920/938 - loss: 0.0585 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n", + "step 930/938 - loss: 0.0414 - acc_top1: 0.9809 - acc_top2: 0.9956 - 15ms/step\n", + "step 938/938 - loss: 0.0180 - acc_top1: 0.9809 - acc_top2: 0.9956 - 15ms/step\n", + "save checkpoint at mnist_checkpoint/1\n", + "Eval begin...\n", + "step 10/157 - loss: 0.1093 - acc_top1: 0.9828 - acc_top2: 0.9984 - 5ms/step\n", + "step 20/157 - loss: 0.2292 - acc_top1: 0.9789 - acc_top2: 0.9969 - 5ms/step\n", + "step 30/157 - loss: 0.1203 - acc_top1: 0.9797 - acc_top2: 0.9969 - 5ms/step\n", + "step 40/157 - loss: 0.0068 - acc_top1: 0.9773 - acc_top2: 0.9961 - 5ms/step\n", + "step 50/157 - loss: 0.0049 - acc_top1: 0.9775 - acc_top2: 0.9959 - 5ms/step\n", + "step 60/157 - loss: 0.0399 - acc_top1: 0.9779 - acc_top2: 0.9956 - 5ms/step\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 70/157 - loss: 0.0299 - acc_top1: 0.9768 - acc_top2: 0.9953 - 5ms/step\n", + "step 80/157 - loss: 0.0108 - acc_top1: 0.9771 - acc_top2: 0.9955 - 5ms/step\n", + "step 90/157 - loss: 0.0209 - acc_top1: 0.9793 - acc_top2: 0.9958 - 5ms/step\n", + "step 100/157 - loss: 0.0031 - acc_top1: 
0.9806 - acc_top2: 0.9962 - 5ms/step\n", + "step 110/157 - loss: 4.0509e-04 - acc_top1: 0.9808 - acc_top2: 0.9962 - 5ms/step\n", + "step 120/157 - loss: 8.9143e-04 - acc_top1: 0.9820 - acc_top2: 0.9965 - 5ms/step\n", + "step 130/157 - loss: 0.0119 - acc_top1: 0.9833 - acc_top2: 0.9968 - 5ms/step\n", + "step 140/157 - loss: 6.7999e-04 - acc_top1: 0.9844 - acc_top2: 0.9970 - 5ms/step\n", + "step 150/157 - loss: 0.0047 - acc_top1: 0.9853 - acc_top2: 0.9972 - 5ms/step\n", + "step 157/157 - loss: 1.6522e-04 - acc_top1: 0.9847 - acc_top2: 0.9973 - 5ms/step\n", + "Eval samples: 10000\n", + "save checkpoint at mnist_checkpoint/final\n" + ] + } + ], + "source": [ + "model.fit(train_dataset,\n", + " test_dataset,\n", + " epochs=2,\n", + " batch_size=64,\n", + " save_dir='mnist_checkpoint')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 组网&训练方式1结束\n", + "以上就是组网&训练方式1,可以非常快速的完成网络模型的构建与训练。此外,paddle还可以用下面的方式来完成模型的训练。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3.组网&训练方式2\n", + "方式1可以快速便捷的完成组网&训练,将细节都隐藏了起来。而方式2则可以用最基本的方式,完成模型的组网与训练。具体如下。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 通过继承Layer的方式来构建模型" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "class LeNet(paddle.nn.Layer):\n", + " def __init__(self):\n", + " super(LeNet, self).__init__()\n", + " self.conv1 = paddle.nn.Conv2D(num_channels=1, num_filters=6, filter_size=5, stride=1, padding=2, act='relu')\n", + " self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.conv2 = paddle.nn.Conv2D(num_channels=6, num_filters=16, filter_size=5, stride=1, act='relu')\n", + " self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.linear1 = paddle.nn.Linear(input_dim=16*5*5, output_dim=120, act='relu')\n", + " self.linear2 = paddle.nn.Linear(input_dim=120, output_dim=84, act='relu')\n", + " self.linear3 = paddle.nn.Linear(input_dim=84, output_dim=10,act='softmax')\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.max_pool1(x)\n", + " x = self.conv2(x)\n", + " x = self.max_pool2(x)\n", + " x = paddle.reshape(x, shape=[-1, 16*5*5])\n", + " x = self.linear1(x)\n", + " x = self.linear2(x)\n", + " x = self.linear3(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, batch_id: 0, loss is: [2.2982373], acc is: [0.15625]\n", + "epoch: 0, batch_id: 100, loss is: [0.25794172], acc is: [0.96875]\n", + "epoch: 0, batch_id: 200, loss is: [0.25025752], acc is: [0.984375]\n", + "epoch: 0, batch_id: 300, loss is: [0.17673397], acc is: [0.984375]\n", + "epoch: 0, batch_id: 400, loss is: [0.09535598], acc is: [1.]\n", + "epoch: 0, batch_id: 500, loss is: [0.08496016], acc is: [1.]\n", + "epoch: 0, batch_id: 600, loss is: [0.14111154], acc is: [0.984375]\n", + "epoch: 0, batch_id: 700, loss is: [0.07322718], acc is: [0.984375]\n", + "epoch: 0, batch_id: 800, loss is: [0.2417614], acc is: [0.984375]\n", + "epoch: 0, batch_id: 900, loss is: [0.10721541], acc is: [1.]\n", + "epoch: 1, batch_id: 0, loss is: [0.02449418], acc is: [1.]\n", + "epoch: 1, batch_id: 100, loss is: [0.151768], acc is: [0.984375]\n", + "epoch: 1, batch_id: 200, loss is: [0.06956144], acc is: 
[0.984375]\n", + "epoch: 1, batch_id: 300, loss is: [0.2008793], acc is: [1.]\n", + "epoch: 1, batch_id: 400, loss is: [0.03839134], acc is: [1.]\n", + "epoch: 1, batch_id: 500, loss is: [0.0217573], acc is: [1.]\n", + "epoch: 1, batch_id: 600, loss is: [0.10977131], acc is: [0.984375]\n", + "epoch: 1, batch_id: 700, loss is: [0.02774046], acc is: [1.]\n", + "epoch: 1, batch_id: 800, loss is: [0.13530938], acc is: [0.984375]\n", + "epoch: 1, batch_id: 900, loss is: [0.0282761], acc is: [1.]\n" + ] + } + ], + "source": [ + "import paddle\n", + "train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n", + "def train(model):\n", + " model.train()\n", + " epochs = 2\n", + " batch_size = 64\n", + " optim = paddle.optimizer.Adam(learning_rate=0.001, parameter_list=model.parameters())\n", + " for epoch in range(epochs):\n", + " for batch_id, data in enumerate(train_loader()):\n", + " x_data = data[0]\n", + " y_data = data[1]\n", + " predicts = model(x_data)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", + " acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_acc = paddle.mean(acc)\n", + " avg_loss.backward()\n", + " if batch_id % 100 == 0:\n", + " print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", + " optim.minimize(avg_loss)\n", + " model.clear_gradients()\n", + "model = LeNet()\n", + "train(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 对模型进行验证" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "batch_id: 0, loss is: [0.0054796], acc is: [1.]\n", + "batch_id: 100, loss is: [0.12248081], acc is: [0.984375]\n", + "batch_id: 200, loss is: [0.06583288], acc is: [1.]\n", + "batch_id: 300, loss is: [0.07927508], acc is: [1.]\n", + "batch_id: 400, loss is: [0.02623187], acc is: [1.]\n", + "batch_id: 500, loss is: [0.02039231], acc is: [1.]\n", + "batch_id: 600, loss is: [0.03374948], acc is: [1.]\n", + "batch_id: 700, loss is: [0.05141395], acc is: [1.]\n", + "batch_id: 800, loss is: [0.1005884], acc is: [1.]\n", + "batch_id: 900, loss is: [0.03581202], acc is: [1.]\n" + ] + } + ], + "source": [ + "import paddle\n", + "test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n", + "def test(model):\n", + " model.eval()\n", + " batch_size = 64\n", + " # 遍历测试集数据进行验证\n", + " for batch_id, data in enumerate(test_loader()):\n", + " x_data = data[0]\n", + " y_data = data[1]\n", + " predicts = model(x_data)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", + " acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_acc = paddle.mean(acc)\n", + " if batch_id % 100 == 0:\n", + " print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", + "test(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 组网&训练方式2结束\n", + "以上就是组网&训练方式2,通过这种方式,可以清楚地看到训练和测试中的每一步过程。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 总结\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "以上就是用LeNet对手写数字数据集MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与训练,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": 
"python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/paddle2.0_docs/n_gram_model/n_gram_model.ipynb b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1e8ab412038461fffd80103059bc0560f4f90b38 --- /dev/null +++ b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb @@ -0,0 +1,344 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 用N-Gram模型在莎士比亚诗中训练word embedding\n", + "N-gram 是计算机语言学和概率论范畴内的概念,是指给定的一段文本中N个项目的序列。\n", + "N=1 时 N-gram 又称为 unigram,N=2 称为 bigram,N=3 称为 trigram,以此类推。实际应用通常采用 bigram 和 trigram 进行计算。\n", + "本示例在莎士比亚十四行诗上实现了trigram。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 环境\n", + "本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.0.0-alpha0'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import paddle\n", + "paddle.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据集&&相关参数\n", + "训练数据集采用了莎士比亚十四行诗,CONTEXT_SIZE设为2,意味着是trigram。EMBEDDING_DIM设为10。" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "CONTEXT_SIZE = 2\n", + "EMBEDDING_DIM = 10\n", + "\n", + "test_sentence = \"\"\"When forty winters shall besiege thy brow,\n", + "And dig deep trenches in thy beauty's field,\n", + "Thy youth's proud livery so gazed on now,\n", + "Will be a totter'd weed of small worth held:\n", + "Then being asked, where all thy beauty lies,\n", + "Where all the treasure of thy lusty days;\n", + "To say, within thine own deep sunken eyes,\n", + "Were an all-eating shame, and thriftless praise.\n", + "How much more praise deserv'd thy beauty's use,\n", + "If thou couldst answer 'This fair child of mine\n", + "Shall sum my count, and make my old excuse,'\n", + "Proving his beauty by succession thine!\n", + "This were to be new made when thou art old,\n", + "And see thy blood warm when thou feel'st it cold.\"\"\".split()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据预处理\n", + "将文本被拆成了元组的形式,格式为(('第一个词', '第二个词'), '第三个词');其中,第三个词就是我们的目标。" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[(('When', 'forty'), 'winters'), (('forty', 'winters'), 'shall'), (('winters', 'shall'), 'besiege')]\n" + ] + } + ], + "source": [ + "trigram = [((test_sentence[i], test_sentence[i + 1]), test_sentence[i + 2])\n", + " for i in range(len(test_sentence) - 2)]\n", + "\n", + "vocab = set(test_sentence)\n", + "word_to_idx = {word: i for i, word in enumerate(vocab)}\n", + "idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n", + "# 看一下数据集\n", + "print(trigram[:3])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 构建`Dataset`类 加载数据" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "class TrainDataset(paddle.io.Dataset):\n", + " def 
__init__(self, tuple_data, vocab):\n", + " self.tuple_data = tuple_data\n", + " self.vocab = vocab\n", + "\n", + " def __getitem__(self, idx):\n", + " data = list(self.tuple_data[idx][0])\n", + " label = list(self.tuple_data[idx][1])\n", + " return data, label\n", + " \n", + " def __len__(self):\n", + " return len(self.tuple_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 组网&训练\n", + "这里用paddle动态图的方式组网,由于是N-Gram模型,只需要一层`Embedding`与两层`Linear`就可以完成网络模型的构建。" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import numpy as np\n", + "class NGramModel(paddle.nn.Layer):\n", + " def __init__(self, vocab_size, embedding_dim, context_size):\n", + " super(NGramModel, self).__init__()\n", + " self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n", + " self.linear1 = paddle.nn.Linear(context_size * embedding_dim, 128)\n", + " self.linear2 = paddle.nn.Linear(128, vocab_size)\n", + "\n", + " def forward(self, x):\n", + " x = self.embedding(x)\n", + " x = paddle.reshape(x, [1, -1])\n", + " x = self.linear1(x)\n", + " x = paddle.nn.functional.relu(x)\n", + " x = self.linear2(x)\n", + " x = paddle.nn.functional.softmax(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 初始化Model,并定义相关的参数。" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, loss is: [4.631529]\n", + "epoch: 50, loss is: [4.6081576]\n", + "epoch: 100, loss is: [4.600631]\n", + "epoch: 150, loss is: [4.603069]\n", + "epoch: 200, loss is: [4.592647]\n", + "epoch: 250, loss is: [4.5626693]\n", + "epoch: 300, loss is: [4.513106]\n", + "epoch: 350, loss is: [4.4345813]\n", + "epoch: 400, loss is: [4.3238697]\n", + "epoch: 450, loss is: [4.1728854]\n", + "epoch: 500, loss is: [3.9622664]\n", + "epoch: 550, loss is: [3.67673]\n", + "epoch: 600, loss is: [3.2998457]\n", + "epoch: 650, loss is: [2.8206367]\n", + "epoch: 700, loss is: [2.2514927]\n", + "epoch: 750, loss is: [1.6479329]\n", + "epoch: 800, loss is: [1.1147357]\n", + "epoch: 850, loss is: [0.73231363]\n", + "epoch: 900, loss is: [0.49481753]\n", + "epoch: 950, loss is: [0.3504072]\n" + ] + } + ], + "source": [ + "vocab_size = len(vocab)\n", + "embedding_dim = 10\n", + "context_size = 2\n", + "\n", + "paddle.enable_imperative()\n", + "losses = []\n", + "def train(model):\n", + " model.train()\n", + " optim = paddle.optimizer.SGD(learning_rate=0.001, parameter_list=model.parameters())\n", + " for epoch in range(1000):\n", + " # 留最后10组作为预测\n", + " for context, target in trigram[:-10]:\n", + " context_idxs = list(map(lambda w: word_to_idx[w], context))\n", + " x_data = paddle.imperative.to_variable(np.array(context_idxs))\n", + " y_data = paddle.imperative.to_variable(np.array([word_to_idx[target]]))\n", + " predicts = model(x_data)\n", + " # print (predicts)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", + " loss.backward()\n", + " optim.minimize(loss)\n", + " model.clear_gradients()\n", + " if epoch % 50 == 0:\n", + " print(\"epoch: {}, loss is: {}\".format(epoch, loss.numpy()))\n", + " losses.append(loss.numpy())\n", + "model = NGramModel(vocab_size, embedding_dim, context_size)\n", + "train(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 打印loss下降曲线\n", + "通过可视化loss的曲线,可以看到模型训练的效果。" + ] + }, + { + "cell_type": "code", + 
"execution_count": 123, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 123, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdkklEQVR4nO3dd3xV9f3H8dfn3uwBmUAggRghMmUFDQ5Qq4JC3VqtVmuttL/aqq3WamtdtcPaarWOuq1aR0XrAJVaxKKIIghE9oaAAQIBAiE7398fuVhUJCHk5tzxfj4e95Hcc06S98nJ451zv/cMc84hIiKhy+d1ABER2T8VtYhIiFNRi4iEOBW1iEiIU1GLiIS4mGB806ysLJefnx+Mby0iEpHmzJmzxTmXva95QSnq/Px8Zs+eHYxvLSISkcxs7dfN09CHiEiIU1GLiIQ4FbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIC8px1G318PSVNDmI8Vnzw+8jPsZHVko82anxZKXEk5kSR6xf/19EJHqEVFHf9fYyauqbWlwuIzmOrJQ4slPjyU5pLvCs1HhS4mNIjPWTGOcnIdZHjM+H32dffJi1OM3s4NclxucjMbY5h7XHNxSRqBVSRV1y8xgampqob3Q0NjkaGpuorm9ky646tuyqpXxn7Vc+zlm3jfKdta0qeK8kxvpJivOTEPiYGOf//B/K3tPTEuPITIkjMyWerOTmj5kpcaQnxeH3qexFolVIFXVcjI+4fQyb98pM3u/XOeeoqmtkd20DNfXN5V5d30hDYxONTY5G11z8X3nsa7pztMdNb5r/yQSy1DVQXd/I7rpGagIfq+saqaiqY8O2wPP6RnZU19PY9NUf7rPmVxGZyfF06RRPbnoiuelJ9EhLJDc9kR7piXRJTVCZi0SokCrqtjIzUuJjSIkP79VpanLsqK5na1Xt568itu6qY+uuWrZU1bFlZy2bKmt4u6ySLbvqvvC1sX4jp3OguNMCRZ6eGCj1RLp3TsSnIhcJS+HdbBHG5zPSk+NIT46jd5f9L1td18iG7dWs37Y78LGaDduan09fXs7mnbVfeGWQGOunsFsqfbum0jcnlcO6pdI/pxNpSXHBXSkROWgq6jCVGOend5cUendJ2ef82oZGyrbXsGF7NaUVu1m2aRdLNlby9uJNvDC79PPl8jOTGJyXxuDcNAbnpTGgeycSYv0dtRoi0goq6ggVH+MnPyuZ/Kwvju875yjfVcuSsp0s+GwHJaU7mLW6glfnfQY0HxrZL6cTg/M6MyI/gxH5GXRPS/RiFUQkwFx7vHP2JUVFRU7Xow4vmyprmFe6nfml25m/fjvzS3ewq7YBgNz0RI7Iz2DEIRkc0zuLvIwkj9OKRB4zm+OcK9rnPBW17Etjk2NxWSWzVlfw8ZoKZq2uYGtV8xuYBdnJjC7MZlRhNiMLMjVUItIOVNRy0JxzrCzfxfRlW/jvsnI+XLWV2oYmEmP9HN83mzEDunFC3y6kJsR6HVUkLKmopd3V1Dfy4aqt/GfxJqYs3ET5zlri/D6O6ZPFqYNyOKl/VzonqrRFWktFLUHV1OSYW7qNtxZs5M0FG1m/rZo4v49RhVmMOzyHE/t11Z62SAtU1NJhnHPMX7+DySWfMbmkjM921BAX42N0YTbjD8/hG/26hv2JSSLBoKIWTzTvaW9nckkZb3xaxsbKGuJjfIwd2I1zhudy1KFZOu1dJEBFLZ5ranLMWbeNV+dt4LV5n1FZ00BO5wTOGtaDc4bncUjW/q/nIhLpVNQSUmrqG5m6eDMT55Ty32XlNDkYVZjNpUfnM7pPtq5JIlFJRS0ha1NlDS98XMozH65l885aCrKSueSofM4enquxbIkqKmoJeXUNTby5oIzHZ6xhful2UhNiuHhkLy47poCMZF04SiKfilrCytx123jkvVW8uWAjCTF+LiruyeWjCuiSmuB1NJGgUVFLWFq+aScPvLuSV+dtIMbv4zvFvbji+N7aw5aIpKKWsLZmSxX3TVvBy5+sJzkuhh+MLuB7xxxCUpzGsCVyqKglIizbtJM7pyzl7UWbyE6N5/qxfTlzaA8dJSIRYX9F/dUbFIqEqMKuqTxycRETfziS7mmJXPPifM59aCYLNuzwOppIUKmoJewU5Wfwr/87ij+efThrtlRx2n3vc+Mrn7J9d13LXywShlTUEpZ8PuO8EXm8c+1xXDwyn2c/Wsfxf3qXf35cSjCG80S8pKKWsNY5MZZbThvA5CuPpU+XVK57qYTvPDaL0ordXkcTaTetLmoz85vZXDObFMxAIm3RL6cTz08o5vYzBjJ33TZOvns6T8xYTVOT9q4l/B3IHvVVwOJgBRE5WD6fcVFxL/79s9EccUgGt76+iHMfmsmKzbu8jiZyUFpV1GaWC4wDHg1uHJGD1yMtkScvHcFd5w1mZfkuTr33Pe6ftoL6xiavo4m0SWv3qP8CXAd87V+6mU0ws9lmNru8vLw9som0mZlx1rBc3v7paE7s14U7pyzl7Ac/YPWWKq+jiRywFovazMYDm51zc/a3nHPuYedckXOuKDs7u90CihyM7NR4HrhwOA9cOIy1W3cz7t73mDhnvY4MkbDSmj3qo4HTzGwN8Dxwgpk9E9RUIu3s1EE5vHnVsQzq0ZlrX5zPlc/PY2dNvdexRFqlxaJ2zt3gnMt1zuUD5wPvOOcuCnoykXbWPS2RZy8v5tqTC5lc8hmn3TeDJRsrvY4l0iIdRy1Rxe8zfnxCH569vJhdtQ2ccf8MXpxd6nUskf06oKJ2zr3rnBsfrDAiHaW4IJPJVx7D0Lx0fj6xhF+/skBHhUjI0h61RK0uqQk88/0j+cGoAp7+cC3ffWIWO3Zr3FpCj4paoprfZ9xwaj/uPOdwZq2u4IwHZrCyXCfISGhRUYsA5xbl8dzlxVRW13Pm/TP4cNVWryOJfE5FLRJQlJ/BK1ccTZdOCVz8+CzeXrTJ60gigIpa5AvyMpJ48Qcj6ZfTiR8+M4eJc9Z7HUlERS3yZenJcTz7/SM56tBMrn1xPn//YI3XkSTKqahF9iE5PoZHLyni5P5dufm1hTwxY7XXkSSKqahFvkZ8jJ/7LxzGmAFdufX1RTz+vspavKGiFtmPWL+P+749jFMGduO2SYt49L1VXkeSKKSiFmlBrN/HvRcM5dRB3bh98mKe1DCIdLAYrwOIhINYv497zh9KQ+Mn3PL6IpLjYzi3KM/rWBIltEct0kqxfh9//fZQju2TxS9eKmFySZnXkSRKqKhFDkB8jJ+HvjOcYT3TufqFuUxbstnrSBIFVNQiBygpLobHLx3BYd1S+eEzc3S6u
QSdilqkDTolxPL3S48gLyOJy578mHml272OJBFMRS3SRpkp8Txz2ZFkpMRxyeOzWLpxp9eRJEKpqEUOQrfOCTz7/WLiY3xc+sQsNlXWeB1JIpCKWuQg5WUk8fh3R7C9up7vPfkxVbUNXkeSCKOiFmkHA3t05v5vD2NxWSU/eW4uDbqtl7QjFbVIOzm+bxduO30g7yzZzC2vL8Q553UkiRA6M1GkHV1U3IvSbbt56L+r6JmRxIRRh3odSSKAilqknf1iTF/WV1TzuzeW0CszmTEDunkdScKchj5E2pnPZ/z5vMEMzu3MNf+cz4rNOmxPDo6KWiQIEmL9PHjRcOJjfEx4eg47a+q9jiRhTEUtEiTd0xK5/8JhrN26m5/9cz5NTXpzUdpGRS0SRMUFmdw4rh9vL9rEfdNWeB1HwpSKWiTIvntUPmcN7cHd/1nG1MWbvI4jYUhFLRJkZsbvzhpE/5xOXP38PFaV7/I6koQZFbVIB0iIbb6OdYzfuOLZudTUN3odScKIilqkg+SmJ/Hn8wazuKyS2ycv8jqOhBEVtUgHOqFvVyaMKuCZD9cxqeQzr+NImFBRi3Swn485jKE907jhpU9Zu7XK6zgSBlTUIh0s1u/j3vOHYgY/eW4u9brSnrRARS3igbyMJO44+3BK1u/gr+/o+GrZPxW1iEdOGZTDWcN6cP+0FXyybpvXcSSEqahFPHTLaQPo1imBn70wj911ujOM7FuLRW1mCWY2y8zmm9lCM7u1I4KJRINOCbH8+bzBrK3YzW8nL/Y6joSo1uxR1wInOOcGA0OAsWZWHNRUIlGkuCCTy48t4B8frWPaks1ex5EQ1GJRu2Z7znmNDTx0GTCRdnTNyYX07ZbKzyeWUFFV53UcCTGtGqM2M7+ZzQM2A2875z7axzITzGy2mc0uLy9v55gikS0+xs/d3xpCZXU9N7xcovstyhe0qqidc43OuSFALnCEmQ3cxzIPO+eKnHNF2dnZ7RxTJPL1y+nEtWMKmbJwE/+au8HrOBJCDuioD+fcdmAaMDYoaUSi3GXHFDC8Vzq3TVpE+c5ar+NIiGjNUR/ZZpYW+DwROAlYEuRcIlHJ7zPuOHsQu2sbueW1hV7HkRDRmj3qHGCamZUAH9M8Rj0puLFEolfvLqlcdWIfJn9axlsLNnodR0JATEsLOOdKgKEdkEVEAiaMKmBySRm/fnUBIwsy6ZwU63Uk8ZDOTBQJQbF+H38853Aqqur4ja5dHfVU1CIhamCPzvxwdAET56xn+jId8hrNVNQiIewnJ/Th0Oxkbnj5U6pqdS2QaKWiFglhCbF+7jj7cDZsr+beqcu9jiMeUVGLhLii/AzOH5HHY++vZunGnV7HEQ+oqEXCwHVj+5KSEMOvX1mg08ujkIpaJAxkJMdx/di+zFpTwcuf6PTyaKOiFgkT5xXlMbRnGr9/czE7dtd7HUc6kIpaJEz4fMbtZwykoqqOP/17qddxpAOpqEXCyIDunbl4ZD7PfLSWkvXbvY4jHURFLRJmfnZyIVkp8dz4ygIam/TGYjRQUYuEmU4Jsdw4rh8l63fw7Kx1XseRDqCiFglDpw3uzlGHZnLnW0vYskvXrY50KmqRMGRm3Hb6QKrrG/n9G7o8fKRTUYuEqd5dUrj82AJe+mQ9H63a6nUcCSIVtUgY+8kJfeiRlshNry6kobHJ6zgSJCpqkTCWGOfn1+P7sXTTTp77uNTrOBIkKmqRMDdmQDeKCzK4699LdcZihFJRi4Q5M+Om8QPYUV3PX6Yu8zqOBIGKWiQC9O/eifOP6MlTM9eyYrMuhRppVNQiEeKakwpJivNz26TFuhRqhFFRi0SIzJR4rvpGH6YvK2fa0s1ex5F2pKIWiSAXj8ynIDuZ2yctpq5Bh+tFChW1SASJi/Hx63H9WbWliqc/XOt1HGknKmqRCHN83y4c2yeLe6cu1+F6EUJFLRKBfnlqPypr6rlvmu5cHglU1CIRqF9OJ84dnsvfP1jLuq27vY4jB0lFLRKhrjn5MPw+444purpeuFNRi0Sorp0SmDCqgMklZcxZu83rOHIQVNQiEWzCqAKyU+P53Rs6CSacqahFIlhyfAzXnFTInLXbeGvBRq/jSBupqEUi3LlFeRzWNZU/vLVEJ8GEKRW1SITz+4xfjuvH2q27dRJMmFJRi0SB0YXZOgkmjKmoRaKEToIJXypqkSihk2DCV4tFbWZ5ZjbNzBaZ2UIzu6ojgolI+9NJMOGpNXvUDcA1zrn+QDFwhZn1D24sEQkGnQQTnlosaudcmXPuk8DnO4HFQI9gBxOR4NBJMOHngMaozSwfGAp8tI95E8xstpnNLi8vb6d4ItLe9j4J5k2dBBMWWl3UZpYCvARc7Zyr/PJ859zDzrki51xRdnZ2e2YUkXb2+Ukwb+okmHDQqqI2s1iaS/ofzrmXgxtJRIJtz0kw6yp289TMNV7HkRa05qgPAx4DFjvn7gp+JBHpCKMLsxldmM09U5dTUVXndRzZj9bsUR8NfAc4wczmBR6nBjmXiHSAX43rR1VtA/dO1UkwoSympQWcc+8D1gFZRKSDFXZN5YIjevL0h2u5qLgXvbukeB1J9kFnJopEuZ+eVEhSrJ/fv7HY6yjyNVTUIlEuKyWeK07ozdQlm3l/+Rav48g+qKhFhO8elU9ueiK3T15EY5NOggk1KmoRISHWz/Wn9GXJxp1MnFPqdRz5EhW1iAAwblAOw3ulc+eUZeyqbfA6juxFRS0iAJgZN47rx5Zdtfzt3ZVex5G9qKhF5HNDe6Zz+pDuPPLeKjZsr/Y6jgSoqEXkC64b2xeAP76la1aHChW1iHxBj7REvn/sIbw67zPmrtM1q0OBilpEvuL/jutNVko8t0/WNatDgYpaRL4iJT6Ga09uvmb15E/LvI4T9VTUIrJP5xbl0bdbKr9/YwnVdY1ex4lqKmoR2Se/z7j1tAFs2F7Ng++u8DpOVFNRi8jXOrIgk9OHdOdv01exdmuV13GilopaRPbrl6f2I9Zn/GbSIq+jRC0VtYjsV9dOCVz5jT78Z/Fm3lmyyes4UUlFLSItuvToQyjITubW1xdRU683FjuailpEWhQX4+OWbw5g7dbdPPb+aq/jRB0VtYi0yqjCbMYO6MZ976zgM10HpEOpqEWk1W4c348m5/itbtvVoVTUItJquelJXHF8byaXlDF9WbnXcaKGilpEDsiEUQUUZCXzq1c+1RmLHURFLSIHJCHWz2/PHERpRTX3TF3udZyooKIWkQM28tBMzh2eyyPvrWJxWaXXcSKeilpE2uSXp/ajc2IsN7z8qe5cHmQqahFpk/TkOG4a3595pdt55sO1XseJaCpqEWmz04d059g+Wdzx1hJKK3Z7HSdiqahFpM3MjN+fNQifGddNLKFJQyBBoaIWkYOSm57EjeP6MXPVVp7WEEhQqKhF5KB9a0Qeowuz+cObS1izRdetbm8qahE5aGbGH84eRIzf+PnE+RoCaWcqahFpFzmdE7n5mwP4eM02Hn1/lddxIoqKWkTazdnDejBmQFfunLKUBRt2eB0nYqioRaTd
mBl/OOtwMpPjufK5uVTVNngdKSKoqEWkXaUnx3H3t4awemsVt72u+yy2BxW1iLS7kYdm8qPjDuWF2aVMLinzOk7Ya7GozexxM9tsZgs6IpCIRIarTyxkaM80fvFSCSvLd3kdJ6y1Zo/6SWBskHOISISJ9fu4/9vDiI/x8cOn57BL49Vt1mJRO+emAxUdkEVEIkz3tET+esFQVpbv4rqJ83FOx1e3hcaoRSSojuqdxfWn9OWNTzfy8HQdX90W7VbUZjbBzGab2ezyct1LTUT+5/JjCxg3KIc73lrCjBVbvI4TdtqtqJ1zDzvnipxzRdnZ2e31bUUkApgZfzzncHp3SeFH//hEby4eIA19iEiHSI6P4bFLRhDrNy594mO27qr1OlLYaM3hec8BM4HDzGy9mV0W/FgiEonyMpJ45OIiNlXW8P2nZlNTr7uYt0Zrjvq4wDmX45yLdc7lOuce64hgIhKZhvZM557zhzCvdDtXPT+XhsYmryOFPA19iEiHGzswh5vG92fKwk26M0wrxHgdQESi06VHH0JVbQN/+vcyEuL8/PaMgZiZ17FCkopaRDxzxfG9qapr5MF3V5IY6+fGcf1U1vugohYRz5gZ1405jOq6Rh57fzVNznHT+P4q6y9RUYuIp8yMm7/ZHzN4YsYaqusa+e2Zg/D7VNZ7qKhFxHNmxk3j+5McF8N901awu66RP583mFi/jncAFbWIhAgz49oxh5EU7+ePby1la1UtD1w4nM6JsV5H85z+XYlISPnRcb3507mDmbW6grMf/IDSit1eR/KcilpEQs45w3N56ntHsrmyhjMfmMGctdu8juQpFbWIhKSRh2by8o+OJjk+hvMfnsmTM1ZH7fWsVdQiErJ6d0nhtSuOYXRhNre8vogrn58XlXc2V1GLSEjrnBTLw98p4udjDmNyyWeM/+v7zF0XXUMhKmoRCXk+n3HF8b159vJi6hqaOOdvM7nr7WXUR8kFnVTUIhI2igsyefPqYzl9SHfunbqcsx74gAUbdngdK+hU1CISVjolxHLXeUN48MJhlO2o4bT73ufW1xeys6be62hBo6IWkbB0yqAcpl4zmguP7MWTH6zhxLv+y6vzNkTkJVNV1CIStjonxvKbMwbyrx8dTVZKPFc9P4/T75/BBxF2A10VtYiEvSF5abz+42O467zBVFTV8e1HP+K7T8yKmPFrC8YB5EVFRW727Nnt/n1FRFpSU9/IUzPXcN87K6isaeD4w7L58Qm9Gd4rw+to+2Vmc5xzRfucp6IWkUhUWVPP0zPX8tj7q6moqqO4IIMJowo4rrALvhC8hKqKWkSi1u66Bp6bVcoj01exsbKGnhlJXFTck/OK8khLivM63udU1CIS9eobm5iycCNPzVzLrNUVxMf4GH94d84e3oPiQzI938tWUYuI7GVxWSVPf7iW1+Z9xq7aBnqkJXLG0O6cOTSX3l1SPMmkohYR2YfqukbeXryJlz9Zz/Rl5TQ5KOyawtgB3RgzsBv9czp12P0bVdQiIi3YXFnD5E/LmLJwI7NWV9DkIC8jkRP7dWVUYTbFh2SSGOcP2s9XUYuIHIAtu2r5z6JNvLVwIzNXbqW2oYm4GB9H5GcwqjCLo3tn0bdbp3a9Aa+KWkSkjWrqG5m1uoLpy8qZvrycZZt2AZCaEMPwXumMyM9gRH4Gh+d2JiG27Xvc+ytq3dxWRGQ/EmL9jCrMZlRhNgAbd9Tw4aqtzFpTwcerK3h36VIA4vw+huSl8fyE4nY/gkRFLSJyALp1TuCMoT04Y2gPALZV1TF77TZmr6lgR3V9UA7zU1GLiByE9OQ4TurflZP6dw3az9BFmUREQpyKWkQkxKmoRURCnIpaRCTEqahFREKcilpEJMSpqEVEQpyKWkQkxAXlWh9mVg6sbeOXZwGRdQvhlmmdo4PWOfIdzPr2cs5l72tGUIr6YJjZ7K+7MEmk0jpHB61z5AvW+mroQ0QkxKmoRURCXCgW9cNeB/CA1jk6aJ0jX1DWN+TGqEVE5ItCcY9aRET2oqIWEQlxIVPUZjbWzJaa2Qozu97rPO3FzPLMbJqZLTKzhWZ2VWB6hpm9bWbLAx/TA9PNzO4N/B5KzGyYt2vQdmbmN7O5ZjYp8PwQM/sosG4vmFlcYHp84PmKwPx8T4O3kZmlmdlEM1tiZovNbGSkb2cz+2ng73qBmT1nZgmRtp3N7HEz22xmC/aadsDb1cwuCSy/3MwuOZAMIVHUZuYH7gdOAfoDF5hZf29TtZsG4BrnXH+gGLgisG7XA1Odc32AqYHn0Pw76BN4TAAe7PjI7eYqYPFez+8A7nbO9Qa2AZcFpl8GbAtMvzuwXDi6B3jLOdcXGEzzukfsdjazHsCVQJFzbiDgB84n8rbzk8DYL007oO1qZhnAzcCRwBHAzXvKvVWcc54/gJHAlL2e3wDc4HWuIK3rq8BJwFIgJzAtB1ga+Pwh4IK9lv98uXB6ALmBP+ATgEmA0XzGVsyXtzkwBRgZ+DwmsJx5vQ4HuL6dgdVfzh3J2xnoAZQCGYHtNgkYE4nbGcgHFrR1uwIXAA/tNf0Ly7X0CIk9av63wfdYH5gWUQIv9YYCHwFdnXNlgVkbgT03XIuU38VfgOuApsDzTGC7c64h8Hzv9fp8nQPzdwSWDyeHAOXAE4HhnkfNLJkI3s7OuQ3An4B1QBnN220Okb2d9zjQ7XpQ2ztUijrimVkK8BJwtXOucu95rvlfbMQcJ2lm44HNzrk5XmfpQDHAMOBB59xQoIr/vRwGInI7pwOn0/xPqjuQzFeHCCJeR2zXUCnqDUDeXs9zA9MigpnF0lzS/3DOvRyYvMnMcgLzc4DNgemR8Ls4GjjNzNYAz9M8/HEPkGZmMYFl9l6vz9c5ML8zsLUjA7eD9cB659xHgecTaS7uSN7OJwKrnXPlzrl64GWat30kb+c9DnS7HtT2DpWi/hjoE3i3OI7mNyRe8zhTuzAzAx4DFjvn7tpr1mvAnnd+L6F57HrP9IsD7x4XAzv2eokVFpxzNzjncp1z+TRvy3eccxcC04BzAot9eZ33/C7OCSwfVnuezrmNQKmZHRaY9A1gERG8nWke8ig2s6TA3/medY7Y7byXA92uU4CTzSw98Erk5MC01vF6kH6vwfVTgWXASuBXXudpx/U6huaXRSXAvMDjVJrH5qYCy4H/ABmB5Y3mI2BWAp/S/I665+txEOt/HDAp8HkBMAtYAbwIxAemJwSerwjML/A6dxvXdQgwO7CtXwHSI307A7cCS4AFwNNAfKRtZ+A5msfg62l+5XRZW7Yr8L3Auq8ALj2QDDqFXEQkxIXK0IeIiHwNFbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIU1GLiIS4/wecQTmUPnjbjwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as ticker\n", + "%matplotlib inline\n", + "\n", + "plt.figure()\n", + "plt.plot(losses)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 预测\n", + "用训练好的模型进行预测。" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the input words is: praise., How\n", + "the predict words is: much\n", + "the true words is: much\n" + ] + } + ], + "source": [ + "import random\n", + "def test(model):\n", + " model.eval()\n", + " # 从最后10组数据中随机选取1个\n", + " idx = random.randint(len(trigram)-10, len(trigram)-1)\n", + " print('the input words is: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n", + " x_data = list(map(lambda w: word_to_idx[w], trigram[idx][0]))\n", + " x_data = paddle.imperative.to_variable(np.array(x_data))\n", + " predicts = model(x_data)\n", + " predicts = predicts.numpy().tolist()[0]\n", + " predicts = predicts.index(max(predicts))\n", + " print('the predict words is: ' + idx_to_word[predicts])\n", + " y_data = trigram[idx][1]\n", + " print('the true words is: ' + y_data)\n", + "test(model)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.3 64-bit", + "language": "python", + "name": "python_defaultSpec_1598180286976" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3-final" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/paddle2.0_docs/text_generation/text_generation_paddle.ipynb b/paddle2.0_docs/text_generation/text_generation_paddle.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..fc47a4196ddeab7bdf79c5777f7e0a6a89351f41 --- /dev/null +++ b/paddle2.0_docs/text_generation/text_generation_paddle.ipynb @@ -0,0 +1,526 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 基于GRU的Text Generation\n", + "文本生成是NLP领域中的重要组成部分,基于GRU,我们可以快速构建文本生成模型。" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.0.0-alpha0'" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import paddle\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "paddle.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 复现过程\n", + "## 1.下载数据\n", + "文件路径:https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt\n", + "保存为txt格式即可" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.读取数据" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Length of text: 1115394 characters\n" + ] + } + ], + "source": [ + "# 文件路径\n", + "path_to_file = './shakespeare.txt'\n", + "text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n", + "\n", + "# 文本长度是指文本中的字符个数\n", + "print ('Length of text: {} characters'.format(len(text)))" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "First Citizen:\n", + "Before we proceed any further, hear me speak.\n", + "\n", + "All:\n", + "Speak, speak.\n", + "\n", + "First Citizen:\n", + "You are all resolved rather to die than to famish?\n", + "\n", + "All:\n", + "Resolved. resolved.\n", + "\n", + "First Citizen:\n", + "First, you know Caius Marcius is chief enemy to the people.\n", + "\n" + ] + } + ], + "source": [ + "# 看一看文本中的前 250 个字符\n", + "print(text[:250])" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "65 unique characters\n" + ] + } + ], + "source": [ + "# 文本中的非重复字符\n", + "vocab = sorted(set(text))\n", + "print ('{} unique characters'.format(len(vocab)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3.向量化文本\n", + "在训练之前,我们需要将字符串映射到数字表示值。创建两个查找表格:一个将字符映射到数字,另一个将数字映射到字符。" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "# 创建从非重复字符到索引的映射\n", + "char2idx = {u:i for i, u in enumerate(vocab)}\n", + "idx2char = np.array(vocab)\n", + "# 用index表示文本\n", + "text_as_int = np.array([char2idx[c] for c in text])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'\\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, \"'\": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}\n" + ] + } + ], + "source": [ + "print(char2idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n", + " 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n", + " 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n", + " 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n" + ] + } + ], + "source": [ + "print(idx2char)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "现在,每个字符都有一个整数表示值。请注意,我们将字符映射至索引 0 至 len(vocab)." + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[18 47 56 ... 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now every character has an integer representation. Note that the characters are mapped to the indices 0 through len(vocab) - 1."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[18 47 56 ... 45 8 0]\n",
+      "1115394\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(text_as_int)\n",
+    "print(len(text_as_int))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 67,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "'First Citizen' ---- characters mapped to int ---- > [18 47 56 57 58 1 15 47 58 47 64 43 52]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# show the integer mapping of the first 13 characters of the text\n",
+    "print('{} ---- characters mapped to int ---- > {}'.format(repr(text[:13]), text_as_int[:13]))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## The prediction task\n",
+    "Given a character, or a sequence of characters, what is the most likely next character? This is the task we train the model to perform. The input to the model is a sequence of characters, and we train the model to predict the output: the next character at each time step."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create training samples and targets\n",
+    "Next, split the text into sample sequences. Each input sequence contains seq_length characters from the text.\n",
+    "\n",
+    "For each input sequence, the corresponding target contains text of the same length, shifted one character to the right.\n",
+    "\n",
+    "Split the text into chunks of length seq_length. For example, if seq_length is 4 and the text is \"Hello\", the input sequence is \"Hell\" and the target sequence is \"ello\"."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "seq_length = 100\n",
+    "def load_data(data, seq_length):\n",
+    "    train_data = []\n",
+    "    train_label = []\n",
+    "    for i in range(len(data)//seq_length):\n",
+    "        train_data.append(data[i*seq_length:(i+1)*seq_length])\n",
+    "        train_label.append(data[i*seq_length + 1:(i+1)*seq_length+1])\n",
+    "    return train_data, train_label\n",
+    "train_data, train_label = load_data(text_as_int, seq_length)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 69,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "training data is :\n",
+      "First Citizen:\n",
+      "Before we proceed any further, hear me speak.\n",
+      "\n",
+      "All:\n",
+      "Speak, speak.\n",
+      "\n",
+      "First Citizen:\n",
+      "You\n",
+      "------------\n",
+      "training_label is:\n",
+      "irst Citizen:\n",
+      "Before we proceed any further, hear me speak.\n",
+      "\n",
+      "All:\n",
+      "Speak, speak.\n",
+      "\n",
+      "First Citizen:\n",
+      "You \n"
+     ]
+    }
+   ],
+   "source": [
+    "char_list = []\n",
+    "label_list = []\n",
+    "for char_id, label_id in zip(train_data[0], train_label[0]):\n",
+    "    char_list.append(idx2char[char_id])\n",
+    "    label_list.append(idx2char[label_id])\n",
+    "\n",
+    "print('training data is :')\n",
+    "print(''.join(char_list))\n",
+    "print(\"------------\")\n",
+    "print('training_label is:')\n",
+    "print(''.join(label_list))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load the data with `paddle.batch`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "batch_size = 64\n",
+    "def train_reader():\n",
+    "    for i in range(len(train_data)):\n",
+    "        yield train_data[i], train_label[i]\n",
+    "batch_reader = paddle.batch(train_reader, batch_size=batch_size)"
+   ]
+  },
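+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before building the model, it can help to look at the layout of one batch. Converting a batch to a numpy array should give the shape (batch_size, 2, seq_length), where axis 1 holds the input sequence and its shifted target; this is exactly how the training loop below slices the data. A small inspection sketch:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# take a single batch from the reader and inspect its shape\n",
+    "one_batch = next(batch_reader())\n",
+    "batch_array = np.array(one_batch)\n",
+    "print(batch_array.shape)  # expected: (64, 2, 100)"
+   ]
+  },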
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Build the text generation model with a GRU"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 71,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import paddle\n",
+    "import numpy as np\n",
+    "\n",
+    "vocab_size = len(vocab)\n",
+    "embedding_dim = 256\n",
+    "hidden_size = 1024\n",
+    "class GRUModel(paddle.nn.Layer):\n",
+    "    def __init__(self):\n",
+    "        super(GRUModel, self).__init__()\n",
+    "        self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n",
+    "        self.gru = paddle.incubate.hapi.text.GRU(input_size=embedding_dim, hidden_size=hidden_size)\n",
+    "        self.linear1 = paddle.nn.Linear(hidden_size, hidden_size//2)\n",
+    "        self.linear2 = paddle.nn.Linear(hidden_size//2, vocab_size)\n",
+    "    def forward(self, x):\n",
+    "        x = self.embedding(x)\n",
+    "        x = paddle.reshape(x, [-1, 1, embedding_dim])\n",
+    "        x, _ = self.gru(x)\n",
+    "        x = paddle.reshape(x, [-1, hidden_size])\n",
+    "        x = self.linear1(x)\n",
+    "        x = paddle.nn.functional.relu(x)\n",
+    "        x = self.linear2(x)\n",
+    "        x = paddle.nn.functional.softmax(x)\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 72,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "epoch: 0, batch: 50, loss is: [3.7835407]\n",
+      "epoch: 0, batch: 100, loss is: [3.2774005]\n",
+      "epoch: 0, batch: 150, loss is: [3.2576294]\n",
+      "epoch: 1, batch: 50, loss is: [3.3434656]\n",
+      "epoch: 1, batch: 100, loss is: [2.9948606]\n",
+      "epoch: 1, batch: 150, loss is: [3.0285468]\n",
+      "epoch: 2, batch: 50, loss is: [3.133882]\n",
+      "epoch: 2, batch: 100, loss is: [2.7811327]\n",
+      "epoch: 2, batch: 150, loss is: [2.8133557]\n",
+      "epoch: 3, batch: 50, loss is: [3.000814]\n",
+      "epoch: 3, batch: 100, loss is: [2.6404488]\n",
+      "epoch: 3, batch: 150, loss is: [2.7050896]\n",
+      "epoch: 4, batch: 50, loss is: [2.9289591]\n",
+      "epoch: 4, batch: 100, loss is: [2.5629177]\n",
+      "epoch: 4, batch: 150, loss is: [2.6438713]\n",
+      "epoch: 5, batch: 50, loss is: [2.8832304]\n",
+      "epoch: 5, batch: 100, loss is: [2.5137548]\n",
+      "epoch: 5, batch: 150, loss is: [2.5926144]\n",
+      "epoch: 6, batch: 50, loss is: [2.8562953]\n",
+      "epoch: 6, batch: 100, loss is: [2.4752126]\n",
+      "epoch: 6, batch: 150, loss is: [2.5510798]\n",
+      "epoch: 7, batch: 50, loss is: [2.8426895]\n",
+      "epoch: 7, batch: 100, loss is: [2.4442513]\n",
+      "epoch: 7, batch: 150, loss is: [2.5187433]\n",
+      "epoch: 8, batch: 50, loss is: [2.8353484]\n",
+      "epoch: 8, batch: 100, loss is: [2.4200597]\n",
+      "epoch: 8, batch: 150, loss is: [2.4956212]\n",
+      "epoch: 9, batch: 50, loss is: [2.8308532]\n",
+      "epoch: 9, batch: 100, loss is: [2.4011066]\n",
+      "epoch: 9, batch: 150, loss is: [2.4787998]\n"
+     ]
+    }
+   ],
+   "source": [
+    "paddle.enable_imperative()\n",
+    "losses = []\n",
+    "def train(model):\n",
+    "    model.train()\n",
+    "    optim = paddle.optimizer.SGD(learning_rate=0.001, parameter_list=model.parameters())\n",
+    "    for epoch in range(10):\n",
+    "        batch_id = 0\n",
+    "        for batch_data in batch_reader():\n",
+    "            batch_id += 1\n",
+    "            data = np.array(batch_data)\n",
+    "            x_data = data[:, 0]\n",
+    "            y_data = data[:, 1]\n",
+    "            for i in range(len(x_data[0])):\n",
+    "                x_char = x_data[:, i]\n",
+    "                y_char = y_data[:, i]\n",
+    "                x_char = paddle.imperative.to_variable(x_char)\n",
+    "                y_char = paddle.imperative.to_variable(y_char)\n",
+    "                predicts = model(x_char)\n",
+    "                loss = paddle.nn.functional.cross_entropy(predicts, y_char)\n",
+    "                avg_loss = paddle.mean(loss)\n",
+    "                avg_loss.backward()\n",
+    "                optim.minimize(avg_loss)\n",
+    "                model.clear_gradients()\n",
+    "            if batch_id % 50 == 0:\n",
+    "                print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n",
+    "                losses.append(avg_loss.numpy())\n",
+    "model = GRUModel()\n",
+    "train(model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Model prediction\n",
+    "Using the trained model, feed in the seed text 'ROMEO:' and automatically generate the next num_generate characters."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 73,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "ROMEO:I the the the the the the the the the the the the the the the the the the the the the the the the th\n"
+     ]
+    }
+   ],
+   "source": [
+    "def generate_text(model, start_string):\n",
+    "    model.eval()\n",
+    "    num_generate = 100\n",
+    "\n",
+    "    # Converting our start string to numbers (vectorizing)\n",
+    "    input_eval = [char2idx[s] for s in start_string]\n",
+    "    input_data = paddle.imperative.to_variable(np.array(input_eval))\n",
+    "    input_data = paddle.reshape(input_data, [-1, 1])\n",
+    "    text_generated = []\n",
+    "\n",
+    "    for i in range(num_generate):\n",
+    "        predicts = model(input_data)\n",
+    "        predicts = predicts.numpy().tolist()[0]\n",
+    "        # greedily pick the character with the highest predicted probability\n",
+    "        predicts_id = predicts.index(max(predicts))\n",
+    "        # feed the predicted character back in as the next input\n",
+    "        input_data = paddle.imperative.to_variable(np.array([predicts_id]))\n",
+    "        input_data = paddle.reshape(input_data, [-1, 1])\n",
+    "        text_generated.append(idx2char[predicts_id])\n",
+    "    return (start_string + ''.join(text_generated))\n",
+    "print(generate_text(model, start_string=u\"ROMEO:\"))"
+   ]
+  },
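+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The greedy decoding above always takes the argmax, which is why the output quickly gets stuck repeating 'the'. A common alternative is to sample the next character from the predicted distribution, optionally sharpened or flattened by a temperature parameter. The sketch below reuses the model trained above and only changes the decoding step; the function name `generate_text_sampled` and the `temperature` argument are illustrative additions, not part of the original tutorial:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_text_sampled(model, start_string, num_generate=100, temperature=1.0):\n",
+    "    # decode by sampling from the softmax output instead of taking the argmax\n",
+    "    model.eval()\n",
+    "    input_eval = [char2idx[s] for s in start_string]\n",
+    "    input_data = paddle.imperative.to_variable(np.array(input_eval))\n",
+    "    input_data = paddle.reshape(input_data, [-1, 1])\n",
+    "    text_generated = []\n",
+    "    for i in range(num_generate):\n",
+    "        predicts = model(input_data)\n",
+    "        # take the probability distribution for the last time step\n",
+    "        probs = np.array(predicts.numpy()[-1], dtype='float64')\n",
+    "        # temperature < 1 sharpens the distribution, > 1 flattens it\n",
+    "        probs = probs ** (1.0 / temperature)\n",
+    "        probs = probs / probs.sum()\n",
+    "        predicts_id = np.random.choice(len(probs), p=probs)\n",
+    "        # feed the sampled character back in as the next input\n",
+    "        input_data = paddle.imperative.to_variable(np.array([predicts_id]))\n",
+    "        input_data = paddle.reshape(input_data, [-1, 1])\n",
+    "        text_generated.append(idx2char[predicts_id])\n",
+    "    return start_string + ''.join(text_generated)\n",
+    "\n",
+    "print(generate_text_sampled(model, start_string=u\"ROMEO:\", temperature=0.8))"
+   ]
+  }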
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}