PaddlePaddle / book
Commit 6cbf7fd8
Authored Aug 26, 2020 by chenlong
add three docs for paddle2.0
Parent: 9fb40dac
Showing 3 changed files with 1536 additions and 0 deletions (+1536, -0)
paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb    +666  -0
paddle2.0_docs/n_gram_model/n_gram_model.ipynb                          +344  -0
paddle2.0_docs/text_generation/text_generation_paddle.ipynb             +526  -0
paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb (new file, mode 100644)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST数据集使用LeNet进行图像分类\n",
"本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。\n",
"手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境\n",
"本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。"
]
},
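{
"cell_type": "markdown",
"metadata": {},
"source": [
"The exact PyPI version tag for the 2.0 alpha is an assumption here (`2.0.0a0`); check PyPI and adjust it for your environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install the paddle 2.0 alpha release (the version tag is an assumption; check PyPI).\n",
"!python -m pip install paddlepaddle==2.0.0a0"
]
},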
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.0.0-alpha0\n"
]
}
],
"source": [
"import paddle\n",
"print(paddle.__version__)\n",
"paddle.enable_imperative()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 加载数据集\n",
"我们使用飞桨自带的paddle.dataset完成mnist数据集的加载。"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"download training data and load training data\n",
"load finished\n",
"<class 'numpy.ndarray'>\n"
]
}
],
"source": [
"print('download training data and load training data')\n",
"train_dataset = paddle.incubate.hapi.datasets.MNIST(mode='train')\n",
"test_dataset = paddle.incubate.hapi.datasets.MNIST(mode='test')\n",
"print('load finished')\n",
"print(type(train_dataset[0][0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"取训练集中的一条数据看一下。"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_data0 label is: [5]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIY0lEQVR4nO3dXWhUZxoH8P/jaPxav7KREtNgiooQFvwg1l1cNOr6sQUN3ixR0VUK9cKPXTBYs17ohReLwl5ovCmuZMU1y+IaWpdC0GIuxCJJMLhJa6oWtSl+FVEXvdDK24s5nc5zapKTZ86cOTPz/4Hk/M8xc17w8Z13zpl5RpxzIBquEbkeAOUnFg6ZsHDIhIVDJiwcMmHhkElGhSMiq0WkT0RuisjesAZF8SfW6zgikgDwFYAVAPoBdABY75z7IrzhUVyNzOB33wVw0zn3NQCIyL8A1AEYsHDKyspcVVVVBqekqHV1dX3nnJvq359J4VQA+CYt9wNYONgvVFVVobOzM4NTUtRE5M6b9md9cSwiH4hIp4h0Pnr0KNuno4hkUjjfAqhMy297+xTn3EfOuRrnXM3UqT+b8ShPZVI4HQBmicg7IlICoB7AJ+EMi+LOvMZxzn0vIjsAtAFIADjhnOsNbWQUa5ksjuGc+xTApyGNhfIIrxyTCQuHTFg4ZMLCIRMWDpmwcMiEhUMmLBwyYeGQCQuHTFg4ZMLCIZOMbnIWk9evX6v89OnTwL/b1NSk8osXL1Tu6+tT+dixYyo3NDSo3NLSovKYMWNU3rv3p88N7N+/P/A4h4MzDpmwcMiEhUMmRbPGuXv3rsovX75U+fLlyypfunRJ5SdPnqh85syZ0MZWWVmp8s6dO1VubW1VecKECSrPmTNH5SVLloQ2toFwxiETFg6ZsHDIpGDXOFevXlV52bJlKg/nOkzYEomEygcPHlR5/PjxKm/cuFHladOmqTxlyhSVZ8+enekQh8QZh0xYOGTCwiGTgl3jTJ8+XeWysjKVw1zjLFyom3T41xwXL15UuaSkROVNmzaFNpaocMYhExYOmbBwyKRg1zilpaUqHz58WOVz586pPG/ePJV37do16OPPnTs3tX3hwgV1zH8dpqenR+UjR44M+tj5gDMOmQxZOCJyQkQeikhP2r5SETkvIje8n1MGewwqPEFmnGYAq3379gL4zDk3C8BnXqYiEqjPsYhUAfivc+5XXu4DUOucuyci5QDanXND3iCpqalxcek6+uzZM5X973HZtm2bysePH1f51KlTqe0NGzaEPLr4EJEu51yNf791jfOWc+6et30fwFvmkVFeynhx7JJT1oDTFtvVFiZr4TzwnqLg/Xw40F9ku9rCZL2O8wmAPwL4q/fz49BGFJGJEycOenzSpEmDHk9f89TX16tjI0YU/lWOIC/HWwB8DmC2iPSLyPtIFswKEbkB4HdepiIy5IzjnFs/wKHlIY+F8kjhz6mUFQV7rypTBw4cULmrq0vl9vb21Lb/XtXKlSuzNazY4IxDJiwcMmHhkIn5Ozkt4nSvarhu3bql8vz581PbkydPVseWLl2qck2NvtWzfft2lUUkhBFmR9j3qqjIsXDIhC/HA5oxY4bKzc3Nqe2tW7eqYydPnhw0P3/+XOXNmzerXF5ebh1mZDjjkAkLh0xYOGTCNY7RunXrUtszZ85Ux3bv3q2y/5ZEY2Ojynfu6O+E37dvn8oVFRXmcWYLZxwyYeGQCQuHTHjLIQv8rW39HzfesmWLyv5/g+XL9Xvkzp8/H9rYhou3HChULBwyYeGQCdc4OTB69GiVX716pfKoUaNUbmtrU7m2tjYr43oTrnEoVCwcMmHhkAnvVYXg2rVrKvu/kqijo0Nl/5rGr7q6WuXFixdnMLrs4IxDJiwcMmHhkAnXOAH5v+L56NGjqe2zZ8+qY/fv3x/WY48cqf8Z/O85jmPblPiNiPJCkP44lSJyUUS+EJFeEfmTt58ta4tYkBnnewC7nXPVAH4NYLuIVIMta4takMZK9wDc87b/LyJfAqgAUAeg1vtr/wDQDuDDrIwyAv51yenTp1VuampS+fbt2+ZzLViwQGX/e4zXrl1rfuyoDGuN4/U7ngfgCtiytqgFLhwR+QWA/wD4s3NOdZcerGUt29UWpkCFIyKjkCyafzrnfnztGahlLdvVFqYh1ziS7MHxdwBfOuf+lnYor1rWPnjwQOXe3l6Vd+zYofL169fN5/J/1eKePXtUrqurUzmO12mGEuQC4CIAmwD8T0S6vX1/QbJg/u21r70D4A9ZGSHFUpBXVZcADNT5hy1ri1T+zZEUCwVzr+rx48cq+782qLu7W2V/a7bhWrRoUWrb/1nxVatWqTx27NiMzhVHnHHIhIVDJiwcMsmrNc6VK1dS24cOHVLH/O/r7e/vz+hc48aNU9n/ddLp95f8XxddDDjjkAkLh0zy6qmqtbX1jdtB+D9ysmbNGpUTiYTKDQ0NKvu7pxc7zjhkwsIhExYOmbDNCQ2KbU4oVCwcMmHhkAkLh0xYOGTCwiETFg6ZsHDIhIVDJiwcMmHhkEmk96pE5BGSn/osA/BdZCcenriOLVfjmu6c+9mH/iMtnNRJRTrfdOMsDuI6triNi09VZMLCIZNcFc5HOTpvEHEdW6zGlZM1DuU/PlWRSaSFIyKrRaRPRG6KSE7b24rICRF5KCI9afti0bs5H3pLR1Y4IpIAcAzA7wFUA1jv9UvOlWYAq3374tK7Of69pZ1zkfwB8BsAbWm5EUBjVOcfYExVAHrSch+Acm+7HEBfLseXNq6PAayI0/iifKqqAPBNWu739sVJ7Ho3x7W3NBfHA3DJ/9Y5fclp7S0dhSgL51sAlWn5bW9fnATq3RyFTHpLRyHKwukAMEtE3hGREgD1SPZKjpMfezcDOezdHKC3NJDr3tIRL/LeA/AVgFsA9uV4wdmC5JebvEJyvfU+gF8i+WrlBoALAEpzNLbfIvk0dA1At/fnvbiMzznHK8dkw8UxmbBwyISFQyYsHDJh4ZAJC4dMWDhkwsIhkx8AyyZIbAmqetUAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 144x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]\n",
"train_data0 = train_data0.transpose(1,2,0)\n",
"plt.figure(figsize=(2,2))\n",
"plt.imshow(train_data0, cmap=plt.cm.binary)\n",
"print('train_data0 label is: ' + str(train_label_0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2.组网&训练方案1\n",
"paddle支持用model类,直接完成模型的训练,具体如下。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 首先需要继承Model来自定义LeNet网络。"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"class LeNet(paddle.incubate.hapi.model.Model):\n",
" def __init__(self):\n",
" super(LeNet, self).__init__()\n",
" self.conv1 = paddle.nn.Conv2D(num_channels=1, num_filters=6, filter_size=5, stride=1, padding=2, act='relu')\n",
" self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n",
" self.conv2 = paddle.nn.Conv2D(num_channels=6, num_filters=16, filter_size=5, stride=1, act='relu')\n",
" self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n",
" self.linear1 = paddle.nn.Linear(input_dim=16*5*5, output_dim=120, act='relu')\n",
" self.linear2 = paddle.nn.Linear(input_dim=120, output_dim=84, act='relu')\n",
" self.linear3 = paddle.nn.Linear(input_dim=84, output_dim=10, act='softmax')\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.max_pool1(x)\n",
" x = self.conv2(x)\n",
" x = self.max_pool2(x)\n",
" x = paddle.reshape(x, shape=[-1, 16*5*5])\n",
" x = self.linear1(x)\n",
" x = self.linear2(x)\n",
" x = self.linear3(x)\n",
" return x"
]
},
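{
"cell_type": "markdown",
"metadata": {},
"source": [
"Where the `16*5*5` flatten size comes from: the 28x28 input stays 28x28 after conv1 (kernel 5, padding 2), halves to 14x14 after max_pool1, shrinks to 10x10 after conv2 (kernel 5, no padding), and halves again to 5x5 after max_pool2, over 16 channels."
]
},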
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 初始化Model,并定义相关的参数。"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from paddle.incubate.hapi.model import Input\n",
"from paddle.incubate.hapi.loss import CrossEntropy\n",
"from paddle.incubate.hapi.metrics import Accuracy\n",
"\n",
"inputs = [Input([None, 1, 28, 28], 'float32', name='image')]\n",
"labels = [Input([None, 1], 'int64', name='label')]\n",
"model = LeNet()\n",
"optim = paddle.optimizer.Adam(learning_rate=0.001, parameter_list=model.parameters())\n",
"\n",
"model.prepare(\n",
" optim,\n",
" CrossEntropy(),\n",
" Accuracy(topk=(1, 2)),\n",
" inputs=inputs,\n",
" labels=labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 使用fit来训练模型"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"step 10/938 - loss: 2.1912 - acc_top1: 0.2719 - acc_top2: 0.4109 - 16ms/step\n",
"step 20/938 - loss: 1.6389 - acc_top1: 0.4109 - acc_top2: 0.5367 - 15ms/step\n",
"step 30/938 - loss: 1.1486 - acc_top1: 0.4797 - acc_top2: 0.6135 - 15ms/step\n",
"step 40/938 - loss: 0.7755 - acc_top1: 0.5484 - acc_top2: 0.6770 - 15ms/step\n",
"step 50/938 - loss: 0.7651 - acc_top1: 0.5975 - acc_top2: 0.7266 - 15ms/step\n",
"step 60/938 - loss: 0.3837 - acc_top1: 0.6393 - acc_top2: 0.7617 - 15ms/step\n",
"step 70/938 - loss: 0.6532 - acc_top1: 0.6712 - acc_top2: 0.7888 - 15ms/step\n",
"step 80/938 - loss: 0.3394 - acc_top1: 0.6969 - acc_top2: 0.8107 - 15ms/step\n",
"step 90/938 - loss: 0.2527 - acc_top1: 0.7189 - acc_top2: 0.8283 - 15ms/step\n",
"step 100/938 - loss: 0.2055 - acc_top1: 0.7389 - acc_top2: 0.8427 - 14ms/step\n",
"step 110/938 - loss: 0.3987 - acc_top1: 0.7531 - acc_top2: 0.8536 - 14ms/step\n",
"step 120/938 - loss: 0.2372 - acc_top1: 0.7660 - acc_top2: 0.8622 - 14ms/step\n",
"step 130/938 - loss: 0.4071 - acc_top1: 0.7780 - acc_top2: 0.8708 - 14ms/step\n",
"step 140/938 - loss: 0.1315 - acc_top1: 0.7895 - acc_top2: 0.8780 - 14ms/step\n",
"step 150/938 - loss: 0.3168 - acc_top1: 0.7981 - acc_top2: 0.8843 - 15ms/step\n",
"step 160/938 - loss: 0.2782 - acc_top1: 0.8063 - acc_top2: 0.8901 - 15ms/step\n",
"step 170/938 - loss: 0.2030 - acc_top1: 0.8144 - acc_top2: 0.8956 - 15ms/step\n",
"step 180/938 - loss: 0.2336 - acc_top1: 0.8203 - acc_top2: 0.9000 - 15ms/step\n",
"step 190/938 - loss: 0.5915 - acc_top1: 0.8260 - acc_top2: 0.9038 - 15ms/step\n",
"step 200/938 - loss: 0.4995 - acc_top1: 0.8310 - acc_top2: 0.9076 - 15ms/step\n",
"step 210/938 - loss: 0.2190 - acc_top1: 0.8359 - acc_top2: 0.9106 - 15ms/step\n",
"step 220/938 - loss: 0.1835 - acc_top1: 0.8397 - acc_top2: 0.9130 - 15ms/step\n",
"step 230/938 - loss: 0.1321 - acc_top1: 0.8442 - acc_top2: 0.9159 - 15ms/step\n",
"step 240/938 - loss: 0.2406 - acc_top1: 0.8478 - acc_top2: 0.9183 - 15ms/step\n",
"step 250/938 - loss: 0.1245 - acc_top1: 0.8518 - acc_top2: 0.9209 - 15ms/step\n",
"step 260/938 - loss: 0.1570 - acc_top1: 0.8559 - acc_top2: 0.9236 - 15ms/step\n",
"step 270/938 - loss: 0.1647 - acc_top1: 0.8593 - acc_top2: 0.9259 - 15ms/step\n",
"step 280/938 - loss: 0.1876 - acc_top1: 0.8625 - acc_top2: 0.9281 - 14ms/step\n",
"step 290/938 - loss: 0.2247 - acc_top1: 0.8650 - acc_top2: 0.9300 - 15ms/step\n",
"step 300/938 - loss: 0.2070 - acc_top1: 0.8679 - acc_top2: 0.9318 - 15ms/step\n",
"step 310/938 - loss: 0.1122 - acc_top1: 0.8701 - acc_top2: 0.9333 - 14ms/step\n",
"step 320/938 - loss: 0.0857 - acc_top1: 0.8729 - acc_top2: 0.9351 - 14ms/step\n",
"step 330/938 - loss: 0.2414 - acc_top1: 0.8751 - acc_top2: 0.9365 - 14ms/step\n",
"step 340/938 - loss: 0.2631 - acc_top1: 0.8774 - acc_top2: 0.9380 - 14ms/step\n",
"step 350/938 - loss: 0.1347 - acc_top1: 0.8796 - acc_top2: 0.9396 - 14ms/step\n",
"step 360/938 - loss: 0.2295 - acc_top1: 0.8816 - acc_top2: 0.9409 - 14ms/step\n",
"step 370/938 - loss: 0.2971 - acc_top1: 0.8842 - acc_top2: 0.9423 - 14ms/step\n",
"step 380/938 - loss: 0.1623 - acc_top1: 0.8863 - acc_top2: 0.9436 - 14ms/step\n",
"step 390/938 - loss: 0.1020 - acc_top1: 0.8880 - acc_top2: 0.9448 - 14ms/step\n",
"step 400/938 - loss: 0.0716 - acc_top1: 0.8895 - acc_top2: 0.9459 - 14ms/step\n",
"step 410/938 - loss: 0.0889 - acc_top1: 0.8914 - acc_top2: 0.9469 - 14ms/step\n",
"step 420/938 - loss: 0.1010 - acc_top1: 0.8931 - acc_top2: 0.9478 - 14ms/step\n",
"step 430/938 - loss: 0.0486 - acc_top1: 0.8945 - acc_top2: 0.9487 - 14ms/step\n",
"step 440/938 - loss: 0.1723 - acc_top1: 0.8958 - acc_top2: 0.9495 - 14ms/step\n",
"step 450/938 - loss: 0.2270 - acc_top1: 0.8974 - acc_top2: 0.9503 - 14ms/step\n",
"step 460/938 - loss: 0.1197 - acc_top1: 0.8987 - acc_top2: 0.9512 - 14ms/step\n",
"step 470/938 - loss: 0.2837 - acc_top1: 0.9002 - acc_top2: 0.9519 - 14ms/step\n",
"step 480/938 - loss: 0.1091 - acc_top1: 0.9017 - acc_top2: 0.9528 - 14ms/step\n",
"step 490/938 - loss: 0.1397 - acc_top1: 0.9029 - acc_top2: 0.9535 - 14ms/step\n",
"step 500/938 - loss: 0.1034 - acc_top1: 0.9040 - acc_top2: 0.9543 - 14ms/step\n",
"step 510/938 - loss: 0.0095 - acc_top1: 0.9054 - acc_top2: 0.9550 - 14ms/step\n",
"step 520/938 - loss: 0.0092 - acc_top1: 0.9068 - acc_top2: 0.9558 - 14ms/step\n",
"step 530/938 - loss: 0.0633 - acc_top1: 0.9077 - acc_top2: 0.9565 - 14ms/step\n",
"step 540/938 - loss: 0.0936 - acc_top1: 0.9086 - acc_top2: 0.9571 - 14ms/step\n",
"step 550/938 - loss: 0.1180 - acc_top1: 0.9097 - acc_top2: 0.9577 - 14ms/step\n",
"step 560/938 - loss: 0.1600 - acc_top1: 0.9106 - acc_top2: 0.9583 - 14ms/step\n",
"step 570/938 - loss: 0.1338 - acc_top1: 0.9118 - acc_top2: 0.9590 - 14ms/step\n",
"step 580/938 - loss: 0.0496 - acc_top1: 0.9128 - acc_top2: 0.9595 - 14ms/step\n",
"step 590/938 - loss: 0.0651 - acc_top1: 0.9138 - acc_top2: 0.9600 - 14ms/step\n",
"step 600/938 - loss: 0.1306 - acc_top1: 0.9147 - acc_top2: 0.9605 - 14ms/step\n",
"step 610/938 - loss: 0.0744 - acc_top1: 0.9157 - acc_top2: 0.9610 - 14ms/step\n",
"step 620/938 - loss: 0.1679 - acc_top1: 0.9166 - acc_top2: 0.9616 - 14ms/step\n",
"step 630/938 - loss: 0.0789 - acc_top1: 0.9173 - acc_top2: 0.9621 - 14ms/step\n",
"step 640/938 - loss: 0.0767 - acc_top1: 0.9182 - acc_top2: 0.9626 - 14ms/step\n",
"step 650/938 - loss: 0.1776 - acc_top1: 0.9188 - acc_top2: 0.9630 - 14ms/step\n",
"step 660/938 - loss: 0.1371 - acc_top1: 0.9196 - acc_top2: 0.9634 - 14ms/step\n",
"step 670/938 - loss: 0.1011 - acc_top1: 0.9204 - acc_top2: 0.9639 - 14ms/step\n",
"step 680/938 - loss: 0.0447 - acc_top1: 0.9209 - acc_top2: 0.9642 - 14ms/step\n",
"step 690/938 - loss: 0.0230 - acc_top1: 0.9217 - acc_top2: 0.9646 - 14ms/step\n",
"step 700/938 - loss: 0.0541 - acc_top1: 0.9224 - acc_top2: 0.9649 - 14ms/step\n",
"step 710/938 - loss: 0.1395 - acc_top1: 0.9231 - acc_top2: 0.9653 - 14ms/step\n",
"step 720/938 - loss: 0.0426 - acc_top1: 0.9238 - acc_top2: 0.9657 - 14ms/step\n",
"step 730/938 - loss: 0.0540 - acc_top1: 0.9247 - acc_top2: 0.9660 - 14ms/step\n",
"step 740/938 - loss: 0.1132 - acc_top1: 0.9253 - acc_top2: 0.9664 - 14ms/step\n",
"step 750/938 - loss: 0.0088 - acc_top1: 0.9261 - acc_top2: 0.9668 - 14ms/step\n",
"step 760/938 - loss: 0.0282 - acc_top1: 0.9266 - acc_top2: 0.9672 - 14ms/step\n",
"step 770/938 - loss: 0.1233 - acc_top1: 0.9272 - acc_top2: 0.9675 - 14ms/step\n",
"step 780/938 - loss: 0.2208 - acc_top1: 0.9275 - acc_top2: 0.9677 - 14ms/step\n",
"step 790/938 - loss: 0.0599 - acc_top1: 0.9281 - acc_top2: 0.9680 - 14ms/step\n",
"step 800/938 - loss: 0.0270 - acc_top1: 0.9287 - acc_top2: 0.9683 - 14ms/step\n",
"step 810/938 - loss: 0.1546 - acc_top1: 0.9291 - acc_top2: 0.9687 - 14ms/step\n",
"step 820/938 - loss: 0.0252 - acc_top1: 0.9297 - acc_top2: 0.9689 - 14ms/step\n",
"step 830/938 - loss: 0.0276 - acc_top1: 0.9304 - acc_top2: 0.9693 - 14ms/step\n",
"step 840/938 - loss: 0.0620 - acc_top1: 0.9309 - acc_top2: 0.9695 - 14ms/step\n",
"step 850/938 - loss: 0.0505 - acc_top1: 0.9314 - acc_top2: 0.9699 - 14ms/step\n",
"step 860/938 - loss: 0.0156 - acc_top1: 0.9319 - acc_top2: 0.9701 - 14ms/step\n",
"step 870/938 - loss: 0.0229 - acc_top1: 0.9325 - acc_top2: 0.9704 - 14ms/step\n",
"step 880/938 - loss: 0.0498 - acc_top1: 0.9330 - acc_top2: 0.9707 - 14ms/step\n",
"step 890/938 - loss: 0.0183 - acc_top1: 0.9335 - acc_top2: 0.9710 - 14ms/step\n",
"step 900/938 - loss: 0.1282 - acc_top1: 0.9339 - acc_top2: 0.9712 - 14ms/step\n",
"step 910/938 - loss: 0.0426 - acc_top1: 0.9342 - acc_top2: 0.9715 - 14ms/step\n",
"step 920/938 - loss: 0.0641 - acc_top1: 0.9347 - acc_top2: 0.9717 - 14ms/step\n",
"step 930/938 - loss: 0.0745 - acc_top1: 0.9351 - acc_top2: 0.9719 - 14ms/step\n",
"step 938/938 - loss: 0.0118 - acc_top1: 0.9354 - acc_top2: 0.9721 - 14ms/step\n",
"save checkpoint at mnist_checkpoint/0\n",
"Eval begin...\n",
"step 10/157 - loss: 0.1032 - acc_top1: 0.9828 - acc_top2: 0.9969 - 5ms/step\n",
"step 20/157 - loss: 0.2664 - acc_top1: 0.9781 - acc_top2: 0.9953 - 5ms/step\n",
"step 30/157 - loss: 0.1626 - acc_top1: 0.9766 - acc_top2: 0.9943 - 5ms/step\n",
"step 40/157 - loss: 0.0247 - acc_top1: 0.9734 - acc_top2: 0.9926 - 5ms/step\n",
"step 50/157 - loss: 0.0225 - acc_top1: 0.9738 - acc_top2: 0.9925 - 5ms/step\n",
"step 60/157 - loss: 0.2119 - acc_top1: 0.9737 - acc_top2: 0.9927 - 5ms/step\n",
"step 70/157 - loss: 0.0559 - acc_top1: 0.9723 - acc_top2: 0.9920 - 5ms/step\n",
"step 80/157 - loss: 0.0329 - acc_top1: 0.9725 - acc_top2: 0.9918 - 5ms/step\n",
"step 90/157 - loss: 0.1064 - acc_top1: 0.9741 - acc_top2: 0.9925 - 5ms/step\n",
"step 100/157 - loss: 0.0027 - acc_top1: 0.9744 - acc_top2: 0.9923 - 5ms/step\n",
"step 110/157 - loss: 0.0044 - acc_top1: 0.9750 - acc_top2: 0.9925 - 5ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 120/157 - loss: 0.0093 - acc_top1: 0.9768 - acc_top2: 0.9931 - 5ms/step\n",
"step 130/157 - loss: 0.1247 - acc_top1: 0.9774 - acc_top2: 0.9935 - 5ms/step\n",
"step 140/157 - loss: 0.0031 - acc_top1: 0.9785 - acc_top2: 0.9940 - 5ms/step\n",
"step 150/157 - loss: 0.0495 - acc_top1: 0.9794 - acc_top2: 0.9944 - 5ms/step\n",
"step 157/157 - loss: 0.0020 - acc_top1: 0.9790 - acc_top2: 0.9944 - 5ms/step\n",
"Eval samples: 10000\n",
"Epoch 2/2\n",
"step 10/938 - loss: 0.1735 - acc_top1: 0.9766 - acc_top2: 0.9938 - 16ms/step\n",
"step 20/938 - loss: 0.0723 - acc_top1: 0.9750 - acc_top2: 0.9922 - 15ms/step\n",
"step 30/938 - loss: 0.0593 - acc_top1: 0.9781 - acc_top2: 0.9927 - 15ms/step\n",
"step 40/938 - loss: 0.1243 - acc_top1: 0.9793 - acc_top2: 0.9938 - 15ms/step\n",
"step 50/938 - loss: 0.0127 - acc_top1: 0.9797 - acc_top2: 0.9944 - 15ms/step\n",
"step 60/938 - loss: 0.0319 - acc_top1: 0.9779 - acc_top2: 0.9938 - 15ms/step\n",
"step 70/938 - loss: 0.0404 - acc_top1: 0.9783 - acc_top2: 0.9946 - 15ms/step\n",
"step 80/938 - loss: 0.1120 - acc_top1: 0.9781 - acc_top2: 0.9943 - 15ms/step\n",
"step 90/938 - loss: 0.0222 - acc_top1: 0.9780 - acc_top2: 0.9944 - 15ms/step\n",
"step 100/938 - loss: 0.0726 - acc_top1: 0.9788 - acc_top2: 0.9948 - 15ms/step\n",
"step 110/938 - loss: 0.0255 - acc_top1: 0.9790 - acc_top2: 0.9952 - 15ms/step\n",
"step 120/938 - loss: 0.2556 - acc_top1: 0.9790 - acc_top2: 0.9948 - 15ms/step\n",
"step 130/938 - loss: 0.0795 - acc_top1: 0.9786 - acc_top2: 0.9945 - 15ms/step\n",
"step 140/938 - loss: 0.1106 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n",
"step 150/938 - loss: 0.0564 - acc_top1: 0.9784 - acc_top2: 0.9946 - 15ms/step\n",
"step 160/938 - loss: 0.1016 - acc_top1: 0.9784 - acc_top2: 0.9947 - 15ms/step\n",
"step 170/938 - loss: 0.0665 - acc_top1: 0.9785 - acc_top2: 0.9946 - 15ms/step\n",
"step 180/938 - loss: 0.0443 - acc_top1: 0.9788 - acc_top2: 0.9946 - 15ms/step\n",
"step 190/938 - loss: 0.0696 - acc_top1: 0.9789 - acc_top2: 0.9947 - 15ms/step\n",
"step 200/938 - loss: 0.0552 - acc_top1: 0.9791 - acc_top2: 0.9948 - 15ms/step\n",
"step 210/938 - loss: 0.1540 - acc_top1: 0.9789 - acc_top2: 0.9946 - 15ms/step\n",
"step 220/938 - loss: 0.0422 - acc_top1: 0.9791 - acc_top2: 0.9947 - 15ms/step\n",
"step 230/938 - loss: 0.2994 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n",
"step 240/938 - loss: 0.0246 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n",
"step 250/938 - loss: 0.0802 - acc_top1: 0.9788 - acc_top2: 0.9946 - 15ms/step\n",
"step 260/938 - loss: 0.1142 - acc_top1: 0.9787 - acc_top2: 0.9947 - 15ms/step\n",
"step 270/938 - loss: 0.0195 - acc_top1: 0.9785 - acc_top2: 0.9946 - 15ms/step\n",
"step 280/938 - loss: 0.0559 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n",
"step 290/938 - loss: 0.1101 - acc_top1: 0.9786 - acc_top2: 0.9943 - 15ms/step\n",
"step 300/938 - loss: 0.0078 - acc_top1: 0.9786 - acc_top2: 0.9943 - 15ms/step\n",
"step 310/938 - loss: 0.0877 - acc_top1: 0.9789 - acc_top2: 0.9944 - 15ms/step\n",
"step 320/938 - loss: 0.0919 - acc_top1: 0.9790 - acc_top2: 0.9945 - 15ms/step\n",
"step 330/938 - loss: 0.0395 - acc_top1: 0.9789 - acc_top2: 0.9945 - 15ms/step\n",
"step 340/938 - loss: 0.1892 - acc_top1: 0.9787 - acc_top2: 0.9945 - 15ms/step\n",
"step 350/938 - loss: 0.0457 - acc_top1: 0.9784 - acc_top2: 0.9944 - 15ms/step\n",
"step 360/938 - loss: 0.1036 - acc_top1: 0.9786 - acc_top2: 0.9944 - 15ms/step\n",
"step 370/938 - loss: 0.0614 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n",
"step 380/938 - loss: 0.2316 - acc_top1: 0.9787 - acc_top2: 0.9944 - 15ms/step\n",
"step 390/938 - loss: 0.0126 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n",
"step 400/938 - loss: 0.0614 - acc_top1: 0.9789 - acc_top2: 0.9946 - 15ms/step\n",
"step 410/938 - loss: 0.0374 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n",
"step 420/938 - loss: 0.0924 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n",
"step 430/938 - loss: 0.0151 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n",
"step 440/938 - loss: 0.0223 - acc_top1: 0.9791 - acc_top2: 0.9947 - 15ms/step\n",
"step 450/938 - loss: 0.0111 - acc_top1: 0.9793 - acc_top2: 0.9947 - 15ms/step\n",
"step 460/938 - loss: 0.0112 - acc_top1: 0.9793 - acc_top2: 0.9947 - 15ms/step\n",
"step 470/938 - loss: 0.0239 - acc_top1: 0.9794 - acc_top2: 0.9947 - 15ms/step\n",
"step 480/938 - loss: 0.0821 - acc_top1: 0.9795 - acc_top2: 0.9948 - 15ms/step\n",
"step 490/938 - loss: 0.0493 - acc_top1: 0.9796 - acc_top2: 0.9948 - 15ms/step\n",
"step 500/938 - loss: 0.0627 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n",
"step 510/938 - loss: 0.0331 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n",
"step 520/938 - loss: 0.0831 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n",
"step 530/938 - loss: 0.0687 - acc_top1: 0.9796 - acc_top2: 0.9949 - 15ms/step\n",
"step 540/938 - loss: 0.1556 - acc_top1: 0.9794 - acc_top2: 0.9949 - 15ms/step\n",
"step 550/938 - loss: 0.2394 - acc_top1: 0.9795 - acc_top2: 0.9950 - 15ms/step\n",
"step 560/938 - loss: 0.0353 - acc_top1: 0.9794 - acc_top2: 0.9950 - 15ms/step\n",
"step 570/938 - loss: 0.0179 - acc_top1: 0.9794 - acc_top2: 0.9951 - 15ms/step\n",
"step 580/938 - loss: 0.0307 - acc_top1: 0.9796 - acc_top2: 0.9951 - 15ms/step\n",
"step 590/938 - loss: 0.0806 - acc_top1: 0.9796 - acc_top2: 0.9952 - 15ms/step\n",
"step 600/938 - loss: 0.0320 - acc_top1: 0.9796 - acc_top2: 0.9953 - 15ms/step\n",
"step 610/938 - loss: 0.0201 - acc_top1: 0.9798 - acc_top2: 0.9953 - 15ms/step\n",
"step 620/938 - loss: 0.1524 - acc_top1: 0.9797 - acc_top2: 0.9953 - 15ms/step\n",
"step 630/938 - loss: 0.0062 - acc_top1: 0.9797 - acc_top2: 0.9953 - 15ms/step\n",
"step 640/938 - loss: 0.0908 - acc_top1: 0.9798 - acc_top2: 0.9953 - 15ms/step\n",
"step 650/938 - loss: 0.0467 - acc_top1: 0.9799 - acc_top2: 0.9954 - 15ms/step\n",
"step 660/938 - loss: 0.0156 - acc_top1: 0.9801 - acc_top2: 0.9954 - 15ms/step\n",
"step 670/938 - loss: 0.0318 - acc_top1: 0.9802 - acc_top2: 0.9955 - 15ms/step\n",
"step 680/938 - loss: 0.0133 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n",
"step 690/938 - loss: 0.0651 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n",
"step 700/938 - loss: 0.0052 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n",
"step 710/938 - loss: 0.1208 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n",
"step 720/938 - loss: 0.1519 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n",
"step 730/938 - loss: 0.0954 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n",
"step 740/938 - loss: 0.0059 - acc_top1: 0.9806 - acc_top2: 0.9955 - 15ms/step\n",
"step 750/938 - loss: 0.1000 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n",
"step 760/938 - loss: 0.0629 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n",
"step 770/938 - loss: 0.0182 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n",
"step 780/938 - loss: 0.0215 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n",
"step 790/938 - loss: 0.0418 - acc_top1: 0.9804 - acc_top2: 0.9956 - 15ms/step\n",
"step 800/938 - loss: 0.0132 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n",
"step 810/938 - loss: 0.0546 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n",
"step 820/938 - loss: 0.0373 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n",
"step 830/938 - loss: 0.0965 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n",
"step 840/938 - loss: 0.0143 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n",
"step 850/938 - loss: 0.0578 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n",
"step 860/938 - loss: 0.0205 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n",
"step 870/938 - loss: 0.0384 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n",
"step 880/938 - loss: 0.0157 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n",
"step 890/938 - loss: 0.0457 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n",
"step 900/938 - loss: 0.0202 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n",
"step 910/938 - loss: 0.0240 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n",
"step 920/938 - loss: 0.0585 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n",
"step 930/938 - loss: 0.0414 - acc_top1: 0.9809 - acc_top2: 0.9956 - 15ms/step\n",
"step 938/938 - loss: 0.0180 - acc_top1: 0.9809 - acc_top2: 0.9956 - 15ms/step\n",
"save checkpoint at mnist_checkpoint/1\n",
"Eval begin...\n",
"step 10/157 - loss: 0.1093 - acc_top1: 0.9828 - acc_top2: 0.9984 - 5ms/step\n",
"step 20/157 - loss: 0.2292 - acc_top1: 0.9789 - acc_top2: 0.9969 - 5ms/step\n",
"step 30/157 - loss: 0.1203 - acc_top1: 0.9797 - acc_top2: 0.9969 - 5ms/step\n",
"step 40/157 - loss: 0.0068 - acc_top1: 0.9773 - acc_top2: 0.9961 - 5ms/step\n",
"step 50/157 - loss: 0.0049 - acc_top1: 0.9775 - acc_top2: 0.9959 - 5ms/step\n",
"step 60/157 - loss: 0.0399 - acc_top1: 0.9779 - acc_top2: 0.9956 - 5ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 70/157 - loss: 0.0299 - acc_top1: 0.9768 - acc_top2: 0.9953 - 5ms/step\n",
"step 80/157 - loss: 0.0108 - acc_top1: 0.9771 - acc_top2: 0.9955 - 5ms/step\n",
"step 90/157 - loss: 0.0209 - acc_top1: 0.9793 - acc_top2: 0.9958 - 5ms/step\n",
"step 100/157 - loss: 0.0031 - acc_top1: 0.9806 - acc_top2: 0.9962 - 5ms/step\n",
"step 110/157 - loss: 4.0509e-04 - acc_top1: 0.9808 - acc_top2: 0.9962 - 5ms/step\n",
"step 120/157 - loss: 8.9143e-04 - acc_top1: 0.9820 - acc_top2: 0.9965 - 5ms/step\n",
"step 130/157 - loss: 0.0119 - acc_top1: 0.9833 - acc_top2: 0.9968 - 5ms/step\n",
"step 140/157 - loss: 6.7999e-04 - acc_top1: 0.9844 - acc_top2: 0.9970 - 5ms/step\n",
"step 150/157 - loss: 0.0047 - acc_top1: 0.9853 - acc_top2: 0.9972 - 5ms/step\n",
"step 157/157 - loss: 1.6522e-04 - acc_top1: 0.9847 - acc_top2: 0.9973 - 5ms/step\n",
"Eval samples: 10000\n",
"save checkpoint at mnist_checkpoint/final\n"
]
}
],
"source": [
"model.fit(train_dataset,\n",
" test_dataset,\n",
" epochs=2,\n",
" batch_size=64,\n",
" save_dir='mnist_checkpoint')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 组网&训练方式1结束\n",
"以上就是组网&训练方式1,可以非常快速的完成网络模型的构建与训练。此外,paddle还可以用下面的方式来完成模型的训练。"
]
},
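{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before moving on, here is a quick sketch of how the checkpoints saved by fit can be reused. It assumes the `Model.load` and `Model.evaluate` methods of the 2.0-alpha hapi API; adjust if your version differs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload the final checkpoint written by fit() and re-run evaluation on the test set.\n",
"# Model.load / Model.evaluate are assumed from the 2.0-alpha hapi API.\n",
"model.load('mnist_checkpoint/final')\n",
"eval_result = model.evaluate(test_dataset, batch_size=64)\n",
"print(eval_result)"
]
},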
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.组网&训练方式2\n",
"方式1可以快速便捷的完成组网&训练,将细节都隐藏了起来。而方式2则可以用最基本的方式,完成模型的组网与训练。具体如下。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 通过继承Layer的方式来构建模型"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"class LeNet(paddle.nn.Layer):\n",
" def __init__(self):\n",
" super(LeNet, self).__init__()\n",
" self.conv1 = paddle.nn.Conv2D(num_channels=1, num_filters=6, filter_size=5, stride=1, padding=2, act='relu')\n",
" self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n",
" self.conv2 = paddle.nn.Conv2D(num_channels=6, num_filters=16, filter_size=5, stride=1, act='relu')\n",
" self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n",
" self.linear1 = paddle.nn.Linear(input_dim=16*5*5, output_dim=120, act='relu')\n",
" self.linear2 = paddle.nn.Linear(input_dim=120, output_dim=84, act='relu')\n",
" self.linear3 = paddle.nn.Linear(input_dim=84, output_dim=10,act='softmax')\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.max_pool1(x)\n",
" x = self.conv2(x)\n",
" x = self.max_pool2(x)\n",
" x = paddle.reshape(x, shape=[-1, 16*5*5])\n",
" x = self.linear1(x)\n",
" x = self.linear2(x)\n",
" x = self.linear3(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 训练模型"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch_id: 0, loss is: [2.2982373], acc is: [0.15625]\n",
"epoch: 0, batch_id: 100, loss is: [0.25794172], acc is: [0.96875]\n",
"epoch: 0, batch_id: 200, loss is: [0.25025752], acc is: [0.984375]\n",
"epoch: 0, batch_id: 300, loss is: [0.17673397], acc is: [0.984375]\n",
"epoch: 0, batch_id: 400, loss is: [0.09535598], acc is: [1.]\n",
"epoch: 0, batch_id: 500, loss is: [0.08496016], acc is: [1.]\n",
"epoch: 0, batch_id: 600, loss is: [0.14111154], acc is: [0.984375]\n",
"epoch: 0, batch_id: 700, loss is: [0.07322718], acc is: [0.984375]\n",
"epoch: 0, batch_id: 800, loss is: [0.2417614], acc is: [0.984375]\n",
"epoch: 0, batch_id: 900, loss is: [0.10721541], acc is: [1.]\n",
"epoch: 1, batch_id: 0, loss is: [0.02449418], acc is: [1.]\n",
"epoch: 1, batch_id: 100, loss is: [0.151768], acc is: [0.984375]\n",
"epoch: 1, batch_id: 200, loss is: [0.06956144], acc is: [0.984375]\n",
"epoch: 1, batch_id: 300, loss is: [0.2008793], acc is: [1.]\n",
"epoch: 1, batch_id: 400, loss is: [0.03839134], acc is: [1.]\n",
"epoch: 1, batch_id: 500, loss is: [0.0217573], acc is: [1.]\n",
"epoch: 1, batch_id: 600, loss is: [0.10977131], acc is: [0.984375]\n",
"epoch: 1, batch_id: 700, loss is: [0.02774046], acc is: [1.]\n",
"epoch: 1, batch_id: 800, loss is: [0.13530938], acc is: [0.984375]\n",
"epoch: 1, batch_id: 900, loss is: [0.0282761], acc is: [1.]\n"
]
}
],
"source": [
"import paddle\n",
"train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"def train(model):\n",
" model.train()\n",
" epochs = 2\n",
" batch_size = 64\n",
" optim = paddle.optimizer.Adam(learning_rate=0.001, parameter_list=model.parameters())\n",
" for epoch in range(epochs):\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
" optim.minimize(avg_loss)\n",
" model.clear_gradients()\n",
"model = LeNet()\n",
"train(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 对模型进行验证"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_id: 0, loss is: [0.0054796], acc is: [1.]\n",
"batch_id: 100, loss is: [0.12248081], acc is: [0.984375]\n",
"batch_id: 200, loss is: [0.06583288], acc is: [1.]\n",
"batch_id: 300, loss is: [0.07927508], acc is: [1.]\n",
"batch_id: 400, loss is: [0.02623187], acc is: [1.]\n",
"batch_id: 500, loss is: [0.02039231], acc is: [1.]\n",
"batch_id: 600, loss is: [0.03374948], acc is: [1.]\n",
"batch_id: 700, loss is: [0.05141395], acc is: [1.]\n",
"batch_id: 800, loss is: [0.1005884], acc is: [1.]\n",
"batch_id: 900, loss is: [0.03581202], acc is: [1.]\n"
]
}
],
"source": [
"import paddle\n",
"test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"def test(model):\n",
" model.eval()\n",
" batch_size = 64\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
"test(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 组网&训练方式2结束\n",
"以上就是组网&训练方式2,通过这种方式,可以清楚的看到训练和测试中的每一步过程。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 总结\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
paddle2.0_docs/n_gram_model/n_gram_model.ipynb (new file, mode 100644)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 用N-Gram模型在莎士比亚诗中训练word embedding\n",
"N-gram 是计算机语言学和概率论范畴内的概念,是指给定的一段文本中N个项目的序列。\n",
"N=1 时 N-gram 又称为 unigram,N=2 称为 bigram,N=3 称为 trigram,以此类推。实际应用通常采用 bigram 和 trigram 进行计算。\n",
"本示例在莎士比亚十四行诗上实现了trigram。"
]
},
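{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the definition concrete, here is a minimal, self-contained sketch of the bigrams and trigrams of a toy sentence (independent of the training code below):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Enumerate the N-grams of a toy sentence.\n",
"toy = 'the cat sat on the mat'.split()\n",
"bigrams = [tuple(toy[i:i+2]) for i in range(len(toy) - 1)]\n",
"trigrams = [tuple(toy[i:i+3]) for i in range(len(toy) - 2)]\n",
"print('bigrams :', bigrams)\n",
"print('trigrams:', trigrams)"
]
},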
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境\n",
"本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.0.0-alpha0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import paddle\n",
"paddle.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据集&&相关参数\n",
"训练数据集采用了莎士比亚十四行诗,CONTEXT_SIZE设为2,意味着是trigram。EMBEDDING_DIM设为10。"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"CONTEXT_SIZE = 2\n",
"EMBEDDING_DIM = 10\n",
"\n",
"test_sentence = \"\"\"When forty winters shall besiege thy brow,\n",
"And dig deep trenches in thy beauty's field,\n",
"Thy youth's proud livery so gazed on now,\n",
"Will be a totter'd weed of small worth held:\n",
"Then being asked, where all thy beauty lies,\n",
"Where all the treasure of thy lusty days;\n",
"To say, within thine own deep sunken eyes,\n",
"Were an all-eating shame, and thriftless praise.\n",
"How much more praise deserv'd thy beauty's use,\n",
"If thou couldst answer 'This fair child of mine\n",
"Shall sum my count, and make my old excuse,'\n",
"Proving his beauty by succession thine!\n",
"This were to be new made when thou art old,\n",
"And see thy blood warm when thou feel'st it cold.\"\"\".split()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据预处理\n",
"将文本被拆成了元组的形式,格式为(('第一个词', '第二个词'), '第三个词');其中,第三个词就是我们的目标。"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(('When', 'forty'), 'winters'), (('forty', 'winters'), 'shall'), (('winters', 'shall'), 'besiege')]\n"
]
}
],
"source": [
"trigram = [((test_sentence[i], test_sentence[i + 1]), test_sentence[i + 2])\n",
" for i in range(len(test_sentence) - 2)]\n",
"\n",
"vocab = set(test_sentence)\n",
"word_to_idx = {word: i for i, word in enumerate(vocab)}\n",
"idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n",
"# 看一下数据集\n",
"print(trigram[:3])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 构建`Dataset`类 加载数据"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"class TrainDataset(paddle.io.Dataset):\n",
" def __init__(self, tuple_data, vocab):\n",
" self.tuple_data = tuple_data\n",
" self.vocab = vocab\n",
"\n",
" def __getitem__(self, idx):\n",
" data = list(self.tuple_data[idx][0])\n",
" label = list(self.tuple_data[idx][1])\n",
" return data, label\n",
" \n",
" def __len__(self):\n",
" return len(self.tuple_data)"
]
},
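{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training loop below consumes the trigram tuples directly rather than this class, but a quick sanity check (a minimal sketch using only the objects defined above) shows what it yields:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Instantiate the dataset and inspect one (context, target) pair.\n",
"dataset = TrainDataset(trigram, vocab)\n",
"print('number of samples:', len(dataset))\n",
"print('first sample:', dataset[0])"
]
},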
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 组网&训练\n",
"这里用paddle动态图的方式组网,由于是N-Gram模型,只需要一层`Embedding`与两层`Linear`就可以完成网络模型的构建。"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import numpy as np\n",
"class NGramModel(paddle.nn.Layer):\n",
" def __init__(self, vocab_size, embedding_dim, context_size):\n",
" super(NGramModel, self).__init__()\n",
" self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n",
" self.linear1 = paddle.nn.Linear(context_size * embedding_dim, 128)\n",
" self.linear2 = paddle.nn.Linear(128, vocab_size)\n",
"\n",
" def forward(self, x):\n",
" x = self.embedding(x)\n",
" x = paddle.reshape(x, [1, -1])\n",
" x = self.linear1(x)\n",
" x = paddle.nn.functional.relu(x)\n",
" x = self.linear2(x)\n",
" x = paddle.nn.functional.softmax(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 初始化Model,并定义相关的参数。"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, loss is: [4.631529]\n",
"epoch: 50, loss is: [4.6081576]\n",
"epoch: 100, loss is: [4.600631]\n",
"epoch: 150, loss is: [4.603069]\n",
"epoch: 200, loss is: [4.592647]\n",
"epoch: 250, loss is: [4.5626693]\n",
"epoch: 300, loss is: [4.513106]\n",
"epoch: 350, loss is: [4.4345813]\n",
"epoch: 400, loss is: [4.3238697]\n",
"epoch: 450, loss is: [4.1728854]\n",
"epoch: 500, loss is: [3.9622664]\n",
"epoch: 550, loss is: [3.67673]\n",
"epoch: 600, loss is: [3.2998457]\n",
"epoch: 650, loss is: [2.8206367]\n",
"epoch: 700, loss is: [2.2514927]\n",
"epoch: 750, loss is: [1.6479329]\n",
"epoch: 800, loss is: [1.1147357]\n",
"epoch: 850, loss is: [0.73231363]\n",
"epoch: 900, loss is: [0.49481753]\n",
"epoch: 950, loss is: [0.3504072]\n"
]
}
],
"source": [
"vocab_size = len(vocab)\n",
"embedding_dim = 10\n",
"context_size = 2\n",
"\n",
"paddle.enable_imperative()\n",
"losses = []\n",
"def train(model):\n",
" model.train()\n",
" optim = paddle.optimizer.SGD(learning_rate=0.001, parameter_list=model.parameters())\n",
" for epoch in range(1000):\n",
" # 留最后10组作为预测\n",
" for context, target in trigram[:-10]:\n",
" context_idxs = list(map(lambda w: word_to_idx[w], context))\n",
" x_data = paddle.imperative.to_variable(np.array(context_idxs))\n",
" y_data = paddle.imperative.to_variable(np.array([word_to_idx[target]]))\n",
" predicts = model(x_data)\n",
" # print (predicts)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" loss.backward()\n",
" optim.minimize(loss)\n",
" model.clear_gradients()\n",
" if epoch % 50 == 0:\n",
" print(\"epoch: {}, loss is: {}\".format(epoch, loss.numpy()))\n",
" losses.append(loss.numpy())\n",
"model = NGramModel(vocab_size, embedding_dim, context_size)\n",
"train(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 打印loss下降曲线\n",
"通过可视化loss的曲线,可以看到模型训练的效果。"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x1434a9358>]"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdkklEQVR4nO3dd3xV9f3H8dfn3uwBmUAggRghMmUFDQ5Qq4JC3VqtVmuttL/aqq3WamtdtcPaarWOuq1aR0XrAJVaxKKIIghE9oaAAQIBAiE7398fuVhUJCHk5tzxfj4e95Hcc06S98nJ451zv/cMc84hIiKhy+d1ABER2T8VtYhIiFNRi4iEOBW1iEiIU1GLiIS4mGB806ysLJefnx+Mby0iEpHmzJmzxTmXva95QSnq/Px8Zs+eHYxvLSISkcxs7dfN09CHiEiIU1GLiIQ4FbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIC8px1G318PSVNDmI8Vnzw+8jPsZHVko82anxZKXEk5kSR6xf/19EJHqEVFHf9fYyauqbWlwuIzmOrJQ4slPjyU5pLvCs1HhS4mNIjPWTGOcnIdZHjM+H32dffJi1OM3s4NclxucjMbY5h7XHNxSRqBVSRV1y8xgampqob3Q0NjkaGpuorm9ky646tuyqpXxn7Vc+zlm3jfKdta0qeK8kxvpJivOTEPiYGOf//B/K3tPTEuPITIkjMyWerOTmj5kpcaQnxeH3qexFolVIFXVcjI+4fQyb98pM3u/XOeeoqmtkd20DNfXN5V5d30hDYxONTY5G11z8X3nsa7pztMdNb5r/yQSy1DVQXd/I7rpGagIfq+saqaiqY8O2wPP6RnZU19PY9NUf7rPmVxGZyfF06RRPbnoiuelJ9EhLJDc9kR7piXRJTVCZi0SokCrqtjIzUuJjSIkP79VpanLsqK5na1Xt568itu6qY+uuWrZU1bFlZy2bKmt4u6ySLbvqvvC1sX4jp3OguNMCRZ6eGCj1RLp3TsSnIhcJS+HdbBHG5zPSk+NIT46jd5f9L1td18iG7dWs37Y78LGaDduan09fXs7mnbVfeGWQGOunsFsqfbum0jcnlcO6pdI/pxNpSXHBXSkROWgq6jCVGOend5cUendJ2ef82oZGyrbXsGF7NaUVu1m2aRdLNlby9uJNvDC79PPl8jOTGJyXxuDcNAbnpTGgeycSYv0dtRoi0goq6ggVH+MnPyuZ/Kwvju875yjfVcuSsp0s+GwHJaU7mLW6glfnfQY0HxrZL6cTg/M6MyI/gxH5GXRPS/RiFUQkwFx7vHP2JUVFRU7Xow4vmyprmFe6nfml25m/fjvzS3ewq7YBgNz0RI7Iz2DEIRkc0zuLvIwkj9OKRB4zm+OcK9rnPBW17Etjk2NxWSWzVlfw8ZoKZq2uYGtV8xuYBdnJjC7MZlRhNiMLMjVUItIOVNRy0JxzrCzfxfRlW/jvsnI+XLWV2oYmEmP9HN83mzEDunFC3y6kJsR6HVUkLKmopd3V1Dfy4aqt/GfxJqYs3ET5zlri/D6O6ZPFqYNyOKl/VzonqrRFWktFLUHV1OSYW7qNtxZs5M0FG1m/rZo4v49RhVmMOzyHE/t11Z62SAtU1NJhnHPMX7+DySWfMbmkjM921BAX42N0YTbjD8/hG/26hv2JSSLBoKIWTzTvaW9nckkZb3xaxsbKGuJjfIwd2I1zhudy1KFZOu1dJEBFLZ5ranLMWbeNV+dt4LV5n1FZ00BO5wTOGtaDc4bncUjW/q/nIhLpVNQSUmrqG5m6eDMT55Ty32XlNDkYVZjNpUfnM7pPtq5JIlFJRS0ha1NlDS98XMozH65l885aCrKSueSofM4enquxbIkqKmoJeXUNTby5oIzHZ6xhful2UhNiuHhkLy47poCMZF04SiKfilrCytx123jkvVW8uWAjCTF+LiruyeWjCuiSmuB1NJGgUVFLWFq+aScPvLuSV+dtIMbv4zvFvbji+N7aw5aIpKKWsLZmSxX3TVvBy5+sJzkuhh+MLuB7xxxCUpzGsCVyqKglIizbtJM7pyzl7UWbyE6N5/qxfTlzaA8dJSIRYX9F/dUbFIqEqMKuqTxycRETfziS7mmJXPPifM59aCYLNuzwOppIUKmoJewU5Wfwr/87ij+efThrtlRx2n3vc+Mrn7J9d13LXywShlTUEpZ8PuO8EXm8c+1xXDwyn2c/Wsfxf3qXf35cSjCG80S8pKKWsNY5MZZbThvA5CuPpU+XVK57qYTvPDaL0ordXkcTaTetLmoz85vZXDObFMxAIm3RL6cTz08o5vYzBjJ33TZOvns6T8xYTVOT9q4l/B3IHvVVwOJgBRE5WD6fcVFxL/79s9EccUgGt76+iHMfmsmKzbu8jiZyUFpV1GaWC4wDHg1uHJGD1yMtkScvHcFd5w1mZfkuTr33Pe6ftoL6xiavo4m0SWv3qP8CXAd87V+6mU0ws9lmNru8vLw9som0mZlx1rBc3v7paE7s14U7pyzl7Ac/YPWWKq+jiRywFovazMYDm51zc/a3nHPuYedckXOuKDs7u90CihyM7NR4HrhwOA9cOIy1W3cz7t73mDhnvY4MkbDSmj3qo4HTzGwN8Dxwgpk9E9RUIu3s1EE5vHnVsQzq0ZlrX5zPlc/PY2dNvdexRFqlxaJ2zt3gnMt1zuUD5wPvOOcuCnoykXbWPS2RZy8v5tqTC5lc8hmn3TeDJRsrvY4l0iIdRy1Rxe8zfnxCH569vJhdtQ2ccf8MXpxd6nUskf06oKJ2zr3rnBsfrDAiHaW4IJPJVx7D0Lx0fj6xhF+/skBHhUjI0h61RK0uqQk88/0j+cGoAp7+cC3ffWIWO3Zr3FpCj4paoprfZ9xwaj/uPOdwZq2u4IwHZrCyXCfISGhRUYsA5xbl8dzlxVRW13Pm/TP4cNVWryOJfE5FLRJQlJ/BK1ccTZdOCVz8+CzeXrTJ60gigIpa5AvyMpJ48Qcj6ZfTiR8+M4eJc9Z7HUlERS3yZenJcTz7/SM56tBMrn1xPn//YI3XkSTKqahF9iE5PoZHLyni5P5dufm1hTwxY7XXkSSKqahFvkZ8jJ/7LxzGmAFdufX1RTz+vspavKGiFtmPWL+P+749jFMGduO2SYt49L1VXkeSKKSiFmlBrN/HvRcM5dRB3bh98mKe1DCIdLAYrwOIhINYv497zh9KQ+Mn3PL6IpLjYzi3KM/rWBIltEct0kqxfh9//fZQju2TxS9eKmFySZnXkSRKqKhFDkB8jJ+HvjOcYT3TufqFuUxbstnrSBIFVNQiBygpLobHLx3BYd1S+eEzc3S6uQSdilqkDTolxPL3S48gLyOJy578mHml272OJBFMRS3SRpkp8Txz2ZFkpMRxyeOzWLpxp9eRJEKpqEUOQrfOCTz7/WLiY3xc+sQsNlXWeB1JIpCKWuQg5WUk8fh3R7C9up7vPfkxVbUNXkeSCKOiFmkHA3t05v5vD2NxWSU/eW4uDbqtl7QjFbVIOzm+bxduO30g7yzZzC2vL8Q55
3UkiRA6M1GkHV1U3IvSbbt56L+r6JmRxIRRh3odSSKAilqknf1iTF/WV1TzuzeW0CszmTEDunkdScKchj5E2pnPZ/z5vMEMzu3MNf+cz4rNOmxPDo6KWiQIEmL9PHjRcOJjfEx4eg47a+q9jiRhTEUtEiTd0xK5/8JhrN26m5/9cz5NTXpzUdpGRS0SRMUFmdw4rh9vL9rEfdNWeB1HwpSKWiTIvntUPmcN7cHd/1nG1MWbvI4jYUhFLRJkZsbvzhpE/5xOXP38PFaV7/I6koQZFbVIB0iIbb6OdYzfuOLZudTUN3odScKIilqkg+SmJ/Hn8wazuKyS2ycv8jqOhBEVtUgHOqFvVyaMKuCZD9cxqeQzr+NImFBRi3Swn485jKE907jhpU9Zu7XK6zgSBlTUIh0s1u/j3vOHYgY/eW4u9brSnrRARS3igbyMJO44+3BK1u/gr+/o+GrZPxW1iEdOGZTDWcN6cP+0FXyybpvXcSSEqahFPHTLaQPo1imBn70wj911ujOM7FuLRW1mCWY2y8zmm9lCM7u1I4KJRINOCbH8+bzBrK3YzW8nL/Y6joSo1uxR1wInOOcGA0OAsWZWHNRUIlGkuCCTy48t4B8frWPaks1ex5EQ1GJRu2Z7znmNDTx0GTCRdnTNyYX07ZbKzyeWUFFV53UcCTGtGqM2M7+ZzQM2A2875z7axzITzGy2mc0uLy9v55gikS0+xs/d3xpCZXU9N7xcovstyhe0qqidc43OuSFALnCEmQ3cxzIPO+eKnHNF2dnZ7RxTJPL1y+nEtWMKmbJwE/+au8HrOBJCDuioD+fcdmAaMDYoaUSi3GXHFDC8Vzq3TVpE+c5ar+NIiGjNUR/ZZpYW+DwROAlYEuRcIlHJ7zPuOHsQu2sbueW1hV7HkRDRmj3qHGCamZUAH9M8Rj0puLFEolfvLqlcdWIfJn9axlsLNnodR0JATEsLOOdKgKEdkEVEAiaMKmBySRm/fnUBIwsy6ZwU63Uk8ZDOTBQJQbF+H38853Aqqur4ja5dHfVU1CIhamCPzvxwdAET56xn+jId8hrNVNQiIewnJ/Th0Oxkbnj5U6pqdS2QaKWiFglhCbF+7jj7cDZsr+beqcu9jiMeUVGLhLii/AzOH5HHY++vZunGnV7HEQ+oqEXCwHVj+5KSEMOvX1mg08ujkIpaJAxkJMdx/di+zFpTwcuf6PTyaKOiFgkT5xXlMbRnGr9/czE7dtd7HUc6kIpaJEz4fMbtZwykoqqOP/17qddxpAOpqEXCyIDunbl4ZD7PfLSWkvXbvY4jHURFLRJmfnZyIVkp8dz4ygIam/TGYjRQUYuEmU4Jsdw4rh8l63fw7Kx1XseRDqCiFglDpw3uzlGHZnLnW0vYskvXrY50KmqRMGRm3Hb6QKrrG/n9G7o8fKRTUYuEqd5dUrj82AJe+mQ9H63a6nUcCSIVtUgY+8kJfeiRlshNry6kobHJ6zgSJCpqkTCWGOfn1+P7sXTTTp77uNTrOBIkKmqRMDdmQDeKCzK4699LdcZihFJRi4Q5M+Om8QPYUV3PX6Yu8zqOBIGKWiQC9O/eifOP6MlTM9eyYrMuhRppVNQiEeKakwpJivNz26TFuhRqhFFRi0SIzJR4rvpGH6YvK2fa0s1ex5F2pKIWiSAXj8ynIDuZ2yctpq5Bh+tFChW1SASJi/Hx63H9WbWliqc/XOt1HGknKmqRCHN83y4c2yeLe6cu1+F6EUJFLRKBfnlqPypr6rlvmu5cHglU1CIRqF9OJ84dnsvfP1jLuq27vY4jB0lFLRKhrjn5MPw+444purpeuFNRi0Sorp0SmDCqgMklZcxZu83rOHIQVNQiEWzCqAKyU+P53Rs6CSacqahFIlhyfAzXnFTInLXbeGvBRq/jSBupqEUi3LlFeRzWNZU/vLVEJ8GEKRW1SITz+4xfjuvH2q27dRJMmFJRi0SB0YXZOgkmjKmoRaKEToIJXypqkSihk2DCV4tFbWZ5ZjbNzBaZ2UIzu6ojgolI+9NJMOGpNXvUDcA1zrn+QDFwhZn1D24sEQkGnQQTnlosaudcmXPuk8DnO4HFQI9gBxOR4NBJMOHngMaozSwfGAp8tI95E8xstpnNLi8vb6d4ItLe9j4J5k2dBBMWWl3UZpYCvARc7Zyr/PJ859zDzrki51xRdnZ2e2YUkXb2+Ukwb+okmHDQqqI2s1iaS/ofzrmXgxtJRIJtz0kw6yp289TMNV7HkRa05qgPAx4DFjvn7gp+JBHpCKMLsxldmM09U5dTUVXndRzZj9bsUR8NfAc4wczmBR6nBjmXiHSAX43rR1VtA/dO1UkwoSympQWcc+8D1gFZRKSDFXZN5YIjevL0h2u5qLgXvbukeB1J9kFnJopEuZ+eVEhSrJ/fv7HY6yjyNVTUIlEuKyWeK07ozdQlm3l/+Rav48g+qKhFhO8elU9ueiK3T15EY5NOggk1KmoRISHWz/Wn9GXJxp1MnFPqdRz5EhW1iAAwblAOw3ulc+eUZeyqbfA6juxFRS0iAJgZN47rx5Zdtfzt3ZVex5G9qKhF5HNDe6Zz+pDuPPLeKjZsr/Y6jgSoqEXkC64b2xeAP76la1aHChW1iHxBj7REvn/sIbw67zPmrtM1q0OBilpEvuL/jutNVko8t0/WNatDgYpaRL4iJT6Ga09uvmb15E/LvI4T9VTUIrJP5xbl0bdbKr9/YwnVdY1ex4lqKmoR2Se/z7j1tAFs2F7Ng++u8DpOVFNRi8jXOrIgk9OHdOdv01exdmuV13GilopaRPbrl6f2I9Zn/GbSIq+jRC0VtYjsV9dOCVz5jT78Z/Fm3lmyyes4UUlFLSItuvToQyjITubW1xdRU683FjuailpEWhQX4+OWbw5g7dbdPPb+aq/jRB0VtYi0yqjCbMYO6MZ976zgM10HpEOpqEWk1W4c348m5/itbtvVoVTUItJquelJXHF8byaXlDF9WbnXcaKGilpEDsiEUQUUZCXzq1c+1RmLHURFLSIHJCHWz2/PHERpRTX3TF3udZyooKIWkQM28tBMzh2eyyPvrWJxWaXXcSKeilpE2uSXp/ajc2IsN7z8qe5cHmQqahFpk/TkOG4a3595pdt55sO1XseJaCpqEWmz04d059g+Wdzx1hJKK3Z7HSdiqahFpM3MjN+fNQifGddNLKFJQyBBoaIWkYOSm57EjeP6MXPVVp7WEEhQqKhF5KB9a0Qeowuz+cObS1izRdetbm8qahE5aGbGH84eRIzf+PnE+RoCaWcqahFpFzmdE7n5mwP4eM02Hn1/lddxIoqKWkTazdnDejBmQFfunLKUBRt2eB0nYqioRaTdmBl/OOtwMpPjufK5uVTVNngdKSKoqEWkXaUnx3H3t4awemsVt72u+yy2BxW1iLS7kYdm8qPjDuWF2aVMLinzOk7Ya7GozexxM9tsZgs6IpCIRIarTyxkaM80fvFSCSvLd3kdJ6y1Zo/6SWBskHOISISJ9fu4/9vDiI/x8cOn57BL49Vt1mJRO+emAxUdkEVEIkz3tET+esFQVpbv
4rqJ83FOx1e3hcaoRSSojuqdxfWn9OWNTzfy8HQdX90W7VbUZjbBzGab2ezyct1LTUT+5/JjCxg3KIc73lrCjBVbvI4TdtqtqJ1zDzvnipxzRdnZ2e31bUUkApgZfzzncHp3SeFH//hEby4eIA19iEiHSI6P4bFLRhDrNy594mO27qr1OlLYaM3hec8BM4HDzGy9mV0W/FgiEonyMpJ45OIiNlXW8P2nZlNTr7uYt0Zrjvq4wDmX45yLdc7lOuce64hgIhKZhvZM557zhzCvdDtXPT+XhsYmryOFPA19iEiHGzswh5vG92fKwk26M0wrxHgdQESi06VHH0JVbQN/+vcyEuL8/PaMgZiZ17FCkopaRDxzxfG9qapr5MF3V5IY6+fGcf1U1vugohYRz5gZ1405jOq6Rh57fzVNznHT+P4q6y9RUYuIp8yMm7/ZHzN4YsYaqusa+e2Zg/D7VNZ7qKhFxHNmxk3j+5McF8N901awu66RP583mFi/jncAFbWIhAgz49oxh5EU7+ePby1la1UtD1w4nM6JsV5H85z+XYlISPnRcb3507mDmbW6grMf/IDSit1eR/KcilpEQs45w3N56ntHsrmyhjMfmMGctdu8juQpFbWIhKSRh2by8o+OJjk+hvMfnsmTM1ZH7fWsVdQiErJ6d0nhtSuOYXRhNre8vogrn58XlXc2V1GLSEjrnBTLw98p4udjDmNyyWeM/+v7zF0XXUMhKmoRCXk+n3HF8b159vJi6hqaOOdvM7nr7WXUR8kFnVTUIhI2igsyefPqYzl9SHfunbqcsx74gAUbdngdK+hU1CISVjolxHLXeUN48MJhlO2o4bT73ufW1xeys6be62hBo6IWkbB0yqAcpl4zmguP7MWTH6zhxLv+y6vzNkTkJVNV1CIStjonxvKbMwbyrx8dTVZKPFc9P4/T75/BBxF2A10VtYiEvSF5abz+42O467zBVFTV8e1HP+K7T8yKmPFrC8YB5EVFRW727Nnt/n1FRFpSU9/IUzPXcN87K6isaeD4w7L58Qm9Gd4rw+to+2Vmc5xzRfucp6IWkUhUWVPP0zPX8tj7q6moqqO4IIMJowo4rrALvhC8hKqKWkSi1u66Bp6bVcoj01exsbKGnhlJXFTck/OK8khLivM63udU1CIS9eobm5iycCNPzVzLrNUVxMf4GH94d84e3oPiQzI938tWUYuI7GVxWSVPf7iW1+Z9xq7aBnqkJXLG0O6cOTSX3l1SPMmkohYR2YfqukbeXryJlz9Zz/Rl5TQ5KOyawtgB3RgzsBv9czp12P0bVdQiIi3YXFnD5E/LmLJwI7NWV9DkIC8jkRP7dWVUYTbFh2SSGOcP2s9XUYuIHIAtu2r5z6JNvLVwIzNXbqW2oYm4GB9H5GcwqjCLo3tn0bdbp3a9Aa+KWkSkjWrqG5m1uoLpy8qZvrycZZt2AZCaEMPwXumMyM9gRH4Gh+d2JiG27Xvc+ytq3dxWRGQ/EmL9jCrMZlRhNgAbd9Tw4aqtzFpTwcerK3h36VIA4vw+huSl8fyE4nY/gkRFLSJyALp1TuCMoT04Y2gPALZV1TF77TZmr6lgR3V9UA7zU1GLiByE9OQ4TurflZP6dw3az9BFmUREQpyKWkQkxKmoRURCnIpaRCTEqahFREKcilpEJMSpqEVEQpyKWkQkxAXlWh9mVg6sbeOXZwGRdQvhlmmdo4PWOfIdzPr2cs5l72tGUIr6YJjZ7K+7MEmk0jpHB61z5AvW+mroQ0QkxKmoRURCXCgW9cNeB/CA1jk6aJ0jX1DWN+TGqEVE5ItCcY9aRET2oqIWEQlxIVPUZjbWzJaa2Qozu97rPO3FzPLMbJqZLTKzhWZ2VWB6hpm9bWbLAx/TA9PNzO4N/B5KzGyYt2vQdmbmN7O5ZjYp8PwQM/sosG4vmFlcYHp84PmKwPx8T4O3kZmlmdlEM1tiZovNbGSkb2cz+2ng73qBmT1nZgmRtp3N7HEz22xmC/aadsDb1cwuCSy/3MwuOZAMIVHUZuYH7gdOAfoDF5hZf29TtZsG4BrnXH+gGLgisG7XA1Odc32AqYHn0Pw76BN4TAAe7PjI7eYqYPFez+8A7nbO9Qa2AZcFpl8GbAtMvzuwXDi6B3jLOdcXGEzzukfsdjazHsCVQJFzbiDgB84n8rbzk8DYL007oO1qZhnAzcCRwBHAzXvKvVWcc54/gJHAlL2e3wDc4HWuIK3rq8BJwFIgJzAtB1ga+Pwh4IK9lv98uXB6ALmBP+ATgEmA0XzGVsyXtzkwBRgZ+DwmsJx5vQ4HuL6dgdVfzh3J2xnoAZQCGYHtNgkYE4nbGcgHFrR1uwIXAA/tNf0Ly7X0CIk9av63wfdYH5gWUQIv9YYCHwFdnXNlgVkbgT03XIuU38VfgOuApsDzTGC7c64h8Hzv9fp8nQPzdwSWDyeHAOXAE4HhnkfNLJkI3s7OuQ3An4B1QBnN220Okb2d9zjQ7XpQ2ztUijrimVkK8BJwtXOucu95rvlfbMQcJ2lm44HNzrk5XmfpQDHAMOBB59xQoIr/vRwGInI7pwOn0/xPqjuQzFeHCCJeR2zXUCnqDUDeXs9zA9MigpnF0lzS/3DOvRyYvMnMcgLzc4DNgemR8Ls4GjjNzNYAz9M8/HEPkGZmMYFl9l6vz9c5ML8zsLUjA7eD9cB659xHgecTaS7uSN7OJwKrnXPlzrl64GWat30kb+c9DnS7HtT2DpWi/hjoE3i3OI7mNyRe8zhTuzAzAx4DFjvn7tpr1mvAnnd+L6F57HrP9IsD7x4XAzv2eokVFpxzNzjncp1z+TRvy3eccxcC04BzAot9eZ33/C7OCSwfVnuezrmNQKmZHRaY9A1gERG8nWke8ig2s6TA3/medY7Y7byXA92uU4CTzSw98Erk5MC01vF6kH6vwfVTgWXASuBXXudpx/U6huaXRSXAvMDjVJrH5qYCy4H/ABmB5Y3mI2BWAp/S/I665+txEOt/HDAp8HkBMAtYAbwIxAemJwSerwjML/A6dxvXdQgwO7CtXwHSI307A7cCS4AFwNNAfKRtZ+A5msfg62l+5XRZW7Yr8L3Auq8ALj2QDDqFXEQkxIXK0IeIiHwNFbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIU1GLiIS4/wecQTmUPnjbjwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker\n",
"%matplotlib inline\n",
"\n",
"plt.figure()\n",
"plt.plot(losses)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 预测\n",
"用训练好的模型进行预测。"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"the input words is: praise., How\n",
"the predict words is: much\n",
"the true words is: much\n"
]
}
],
"source": [
"import random\n",
"def test(model):\n",
" model.eval()\n",
" # 从最后10组数据中随机选取1个\n",
" idx = random.randint(len(trigram)-10, len(trigram)-1)\n",
" print('the input words is: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n",
" x_data = list(map(lambda w: word_to_idx[w], trigram[idx][0]))\n",
" x_data = paddle.imperative.to_variable(np.array(x_data))\n",
" predicts = model(x_data)\n",
" predicts = predicts.numpy().tolist()[0]\n",
" predicts = predicts.index(max(predicts))\n",
" print('the predict words is: ' + idx_to_word[predicts])\n",
" y_data = trigram[idx][1]\n",
" print('the true words is: ' + y_data)\n",
"test(model)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.3 64-bit",
"language": "python",
"name": "python_defaultSpec_1598180286976"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3-final"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
paddle2.0_docs/text_generation/text_generation_paddle.ipynb (new file, mode 100644)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 基于GRU的Text Generation\n",
"文本生成是NLP领域中的重要组成部分,基于GRU,我们可以快速构建文本生成模型。"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.0.0-alpha0'"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import paddle\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"paddle.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 复现过程\n",
"## 1.下载数据\n",
"文件路径:https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt\n",
"保存为txt格式即可"
]
},
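{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal download sketch using only the Python standard library; the target filename matches the path used below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import urllib.request\n",
"\n",
"# Fetch the corpus once and cache it next to the notebook.\n",
"url = 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt'\n",
"if not os.path.exists('./shakespeare.txt'):\n",
"    urllib.request.urlretrieve(url, './shakespeare.txt')"
]
},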
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.读取数据"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Length of text: 1115394 characters\n"
]
}
],
"source": [
"# 文件路径\n",
"path_to_file = './shakespeare.txt'\n",
"text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n",
"\n",
"# 文本长度是指文本中的字符个数\n",
"print ('Length of text: {} characters'.format(len(text)))"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First Citizen:\n",
"Before we proceed any further, hear me speak.\n",
"\n",
"All:\n",
"Speak, speak.\n",
"\n",
"First Citizen:\n",
"You are all resolved rather to die than to famish?\n",
"\n",
"All:\n",
"Resolved. resolved.\n",
"\n",
"First Citizen:\n",
"First, you know Caius Marcius is chief enemy to the people.\n",
"\n"
]
}
],
"source": [
"# 看一看文本中的前 250 个字符\n",
"print(text[:250])"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"65 unique characters\n"
]
}
],
"source": [
"# 文本中的非重复字符\n",
"vocab = sorted(set(text))\n",
"print ('{} unique characters'.format(len(vocab)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.向量化文本\n",
"在训练之前,我们需要将字符串映射到数字表示值。创建两个查找表格:一个将字符映射到数字,另一个将数字映射到字符。"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"# 创建从非重复字符到索引的映射\n",
"char2idx = {u:i for i, u in enumerate(vocab)}\n",
"idx2char = np.array(vocab)\n",
"# 用index表示文本\n",
"text_as_int = np.array([char2idx[c] for c in text])\n"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'\\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, \"'\": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}\n"
]
}
],
"source": [
"print(char2idx)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n",
" 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n",
" 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n",
" 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n"
]
}
],
"source": [
"print(idx2char)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"现在,每个字符都有一个整数表示值。请注意,我们将字符映射至索引 0 至 len(vocab)."
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[18 47 56 ... 45 8 0]\n",
"1115394\n"
]
}
],
"source": [
"print(text_as_int)\n",
"print(len(text_as_int))"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"'First Citizen' ---- characters mapped to int ---- > [18 47 56 57 58 1 15 47 58 47 64 43 52]\n"
]
}
],
"source": [
"# 显示文本首 13 个字符的整数映射\n",
"print ('{} ---- characters mapped to int ---- > {}'.format(repr(text[:13]), text_as_int[:13]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 预测任务\n",
"给定一个字符或者一个字符序列,下一个最可能出现的字符是什么?这就是我们训练模型要执行的任务。输入进模型的是一个字符序列,我们训练这个模型来预测输出 -- 每个时间步(time step)预测下一个字符是什么。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 创建训练样本和目标\n",
"接下来,将文本划分为样本序列。每个输入序列包含文本中的 seq_length 个字符。\n",
"\n",
"对于每个输入序列,其对应的目标包含相同长度的文本,但是向右顺移一个字符。\n",
"\n",
"将文本拆分为长度为 seq_length 的文本块。例如,假设 seq_length 为 4 而且文本为 “Hello”, 那么输入序列将为 “Hell”,目标序列将为 “ello”。"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"seq_length = 100\n",
"def load_data(data, seq_length):\n",
" train_data = []\n",
" train_label = []\n",
" for i in range(len(data)//seq_length):\n",
" train_data.append(data[i*seq_length:(i+1)*seq_length])\n",
" train_label.append(data[i*seq_length + 1:(i+1)*seq_length+1])\n",
" return train_data, train_label\n",
"train_data, train_label = load_data(text_as_int, seq_length)"
]
},
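{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of load_data on the \"Hello\" example described above (a minimal sketch on plain Python lists):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# With seq_length=4, \"Hello\" yields the input \"Hell\" and the target \"ello\".\n",
"toy_data, toy_label = load_data(list('Hello'), 4)\n",
"print('input :', ''.join(toy_data[0]))\n",
"print('target:', ''.join(toy_label[0]))"
]
},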
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training data is :\n",
"First Citizen:\n",
"Before we proceed any further, hear me speak.\n",
"\n",
"All:\n",
"Speak, speak.\n",
"\n",
"First Citizen:\n",
"You\n",
"------------\n",
"training_label is:\n",
"irst Citizen:\n",
"Before we proceed any further, hear me speak.\n",
"\n",
"All:\n",
"Speak, speak.\n",
"\n",
"First Citizen:\n",
"You \n"
]
}
],
"source": [
"char_list = []\n",
"label_list = []\n",
"for char_id, label_id in zip(train_data[0], train_label[0]):\n",
" char_list.append(idx2char[char_id])\n",
" label_list.append(idx2char[label_id])\n",
"\n",
"print('training data is :')\n",
"print(''.join(char_list))\n",
"print(\"------------\")\n",
"print('training_label is:')\n",
"print(''.join(label_list))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 用`paddle.batch`完成数据的加载"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"batch_size = 64\n",
"def train_reader():\n",
" for i in range(len(train_data)):\n",
" yield train_data[i], train_label[i]\n",
"batch_reader = paddle.batch(train_reader, batch_size=batch_size) "
]
},
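{
"cell_type": "markdown",
"metadata": {},
"source": [
"Peeking at one batch shows the shape the training loop below consumes; each batch is a list of (input, target) pairs, which stacks to (batch_size, 2, seq_length):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the stacked shape of the first batch.\n",
"first_batch = next(batch_reader())\n",
"print(np.array(first_batch).shape)"
]
},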
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 基于GRU构建文本生成模型"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import numpy as np\n",
"\n",
"vocab_size = len(vocab)\n",
"embedding_dim = 256\n",
"hidden_size = 1024\n",
"class GRUModel(paddle.nn.Layer):\n",
" def __init__(self):\n",
" super(GRUModel, self).__init__()\n",
" self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n",
" self.gru = paddle.incubate.hapi.text.GRU(input_size=embedding_dim, hidden_size=hidden_size)\n",
" self.linear1 = paddle.nn.Linear(hidden_size, hidden_size//2)\n",
" self.linear2 = paddle.nn.Linear(hidden_size//2, vocab_size)\n",
" def forward(self, x):\n",
" x = self.embedding(x)\n",
" x = paddle.reshape(x, [-1, 1, embedding_dim])\n",
" x, _ = self.gru(x)\n",
" x = paddle.reshape(x, [-1, hidden_size])\n",
" x = self.linear1(x)\n",
" x = paddle.nn.functional.relu(x)\n",
" x = self.linear2(x)\n",
" x = paddle.nn.functional.softmax(x)\n",
" return x"
]
},
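{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that forward reshapes its input to sequences of length 1 and discards the GRU's final state, so each character is processed independently: hidden state is not carried across time steps in this setup."
]
},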
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch: 50, loss is: [3.7835407]\n",
"epoch: 0, batch: 100, loss is: [3.2774005]\n",
"epoch: 0, batch: 150, loss is: [3.2576294]\n",
"epoch: 1, batch: 50, loss is: [3.3434656]\n",
"epoch: 1, batch: 100, loss is: [2.9948606]\n",
"epoch: 1, batch: 150, loss is: [3.0285468]\n",
"epoch: 2, batch: 50, loss is: [3.133882]\n",
"epoch: 2, batch: 100, loss is: [2.7811327]\n",
"epoch: 2, batch: 150, loss is: [2.8133557]\n",
"epoch: 3, batch: 50, loss is: [3.000814]\n",
"epoch: 3, batch: 100, loss is: [2.6404488]\n",
"epoch: 3, batch: 150, loss is: [2.7050896]\n",
"epoch: 4, batch: 50, loss is: [2.9289591]\n",
"epoch: 4, batch: 100, loss is: [2.5629177]\n",
"epoch: 4, batch: 150, loss is: [2.6438713]\n",
"epoch: 5, batch: 50, loss is: [2.8832304]\n",
"epoch: 5, batch: 100, loss is: [2.5137548]\n",
"epoch: 5, batch: 150, loss is: [2.5926144]\n",
"epoch: 6, batch: 50, loss is: [2.8562953]\n",
"epoch: 6, batch: 100, loss is: [2.4752126]\n",
"epoch: 6, batch: 150, loss is: [2.5510798]\n",
"epoch: 7, batch: 50, loss is: [2.8426895]\n",
"epoch: 7, batch: 100, loss is: [2.4442513]\n",
"epoch: 7, batch: 150, loss is: [2.5187433]\n",
"epoch: 8, batch: 50, loss is: [2.8353484]\n",
"epoch: 8, batch: 100, loss is: [2.4200597]\n",
"epoch: 8, batch: 150, loss is: [2.4956212]\n",
"epoch: 9, batch: 50, loss is: [2.8308532]\n",
"epoch: 9, batch: 100, loss is: [2.4011066]\n",
"epoch: 9, batch: 150, loss is: [2.4787998]\n"
]
}
],
"source": [
"paddle.enable_imperative()\n",
"losses = []\n",
"def train(model):\n",
" model.train()\n",
" optim = paddle.optimizer.SGD(learning_rate=0.001, parameter_list=model.parameters())\n",
" for epoch in range(10):\n",
" batch_id = 0\n",
" for batch_data in batch_reader():\n",
" batch_id += 1\n",
" data = np.array(batch_data)\n",
" x_data = data[:, 0]\n",
" y_data = data[:, 1]\n",
" for i in range(len(x_data[0])):\n",
" x_char = x_data[:, i]\n",
" y_char = y_data[:, i]\n",
" x_char = paddle.imperative.to_variable(x_char)\n",
" y_char = paddle.imperative.to_variable(y_char)\n",
" predicts = model(x_char)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_char)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_loss.backward()\n",
" optim.minimize(avg_loss)\n",
" model.clear_gradients()\n",
" if batch_id % 50 == 0:\n",
" print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n",
" losses.append(loss.numpy())\n",
"model = GRUModel()\n",
"train(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 模型预测\n",
"利用训练好的模型,输出初始化文本'ROMEO: ',自动生成后续的num_generate个字符。"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ROMEO:I the the the the the the the the the the the the the the the the the the the the the the the the th\n"
]
}
],
"source": [
"def generate_text(model, start_string):\n",
" \n",
" model.eval()\n",
" num_generate = 100\n",
"\n",
" # Converting our start string to numbers (vectorizing)\n",
" input_eval = [char2idx[s] for s in start_string]\n",
" input_data = paddle.imperative.to_variable(np.array(input_eval))\n",
" input_data = paddle.reshape(input_data, [-1, 1])\n",
" text_generated = []\n",
"\n",
" for i in range(num_generate):\n",
" predicts = model(input_data)\n",
" predicts = predicts.numpy().tolist()[0]\n",
" # print(predicts)\n",
" predicts_id = predicts.index(max(predicts))\n",
" # print(predicts_id)\n",
" # using a categorical distribution to predict the character returned by the model\n",
" input_data = paddle.imperative.to_variable(np.array([predicts_id]))\n",
" input_data = paddle.reshape(input_data, [-1, 1])\n",
" text_generated.append(idx2char[predicts_id])\n",
" return (start_string + ''.join(text_generated))\n",
"print(generate_text(model, start_string=u\"ROMEO:\"))"
]
},
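{
"cell_type": "markdown",
"metadata": {},
"source": [
"Greedy argmax decoding collapses into repetition (\"the the the ...\", as seen above). A common alternative is to sample the next character from the predicted distribution; here is a minimal numpy sketch (renormalizing to guard against floating-point drift) that could replace the argmax line in generate_text:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sample_next_id(predicts):\n",
"    # predicts is the softmax output for one step; sample an index instead of taking the argmax.\n",
"    probs = np.asarray(predicts, dtype='float64')\n",
"    probs = probs / probs.sum()\n",
"    return int(np.random.choice(len(probs), p=probs))"
]
}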
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}