{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 5.5 卷积神经网络(LeNet)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-05-29T13:57:37.383972Z", "start_time": "2019-05-29T13:57:34.520559Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0.0\n", "cuda\n" ] } ], "source": [ "import os\n", "import time\n", "import torch\n", "from torch import nn, optim\n", "\n", "import sys\n", "sys.path.append(\"..\") \n", "import d2lzh_pytorch as d2l\n", "\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "print(torch.__version__)\n", "print(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5.5.1 LeNet模型 " ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-05-29T13:57:37.394997Z", "start_time": "2019-05-29T13:57:37.386720Z" } }, "outputs": [], "source": [ "class LeNet(nn.Module):\n", " def __init__(self):\n", " super(LeNet, self).__init__()\n", " self.conv = nn.Sequential(\n", " nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size\n", " nn.Sigmoid(),\n", " nn.MaxPool2d(2, 2), # kernel_size, stride\n", " nn.Conv2d(6, 16, 5),\n", " nn.Sigmoid(),\n", " nn.MaxPool2d(2, 2)\n", " )\n", " self.fc = nn.Sequential(\n", " nn.Linear(16*4*4, 120),\n", " nn.Sigmoid(),\n", " nn.Linear(120, 84),\n", " nn.Sigmoid(),\n", " nn.Linear(84, 10)\n", " )\n", "\n", " def forward(self, img):\n", " feature = self.conv(img)\n", " output = self.fc(feature.view(img.shape[0], -1))\n", " return output" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2019-05-29T13:57:37.450484Z", "start_time": "2019-05-29T13:57:37.397357Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LeNet(\n", " (conv): Sequential(\n", " (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n", " (1): Sigmoid()\n", " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n", " (4): Sigmoid()\n", " (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (fc): Sequential(\n", " (0): Linear(in_features=256, out_features=120, bias=True)\n", " (1): Sigmoid()\n", " (2): Linear(in_features=120, out_features=84, bias=True)\n", " (3): Sigmoid()\n", " (4): Linear(in_features=84, out_features=10, bias=True)\n", " )\n", ")\n" ] } ], "source": [ "net = LeNet()\n", "print(net)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5.5.2 获取数据和训练模型" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-05-29T13:57:38.432567Z", "start_time": "2019-05-29T13:57:37.452521Z" } }, "outputs": [], "source": [ "batch_size = 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2019-05-29T13:57:38.442887Z", "start_time": "2019-05-29T13:57:38.435111Z" } }, "outputs": [], "source": [ "# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进:它的完整实现将在“图像增广”一节中描述\n", "def evaluate_accuracy(data_iter, net, device=None):\n", " if device is None and isinstance(net, torch.nn.Module):\n", " # 如果没指定device就使用net的device\n", " device = list(net.parameters())[0].device\n", " acc_sum, n = 0.0, 0\n", " with torch.no_grad():\n", " for X, y in data_iter:\n", " if isinstance(net, torch.nn.Module):\n", " net.eval() # 评估模式, 这会关闭dropout\n", " acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()\n", " net.train() # 改回训练模式\n", " else: # 自定义的模型, 3.13节之后不会用到, 不考虑GPU\n", " if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数\n", " # 将is_training设置成False\n", " acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() \n", " else:\n", " acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() \n", " n += y.shape[0]\n", " return acc_sum / n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2019-05-29T13:57:38.453480Z", "start_time": "2019-05-29T13:57:38.445655Z" } }, "outputs": [], "source": [ "# 本函数已保存在d2lzh_pytorch包中方便以后使用\n", "def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):\n", " net = net.to(device)\n", " print(\"training on \", device)\n", " loss = torch.nn.CrossEntropyLoss()\n", " batch_count = 0\n", " for epoch in range(num_epochs):\n", " train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()\n", " for X, y in train_iter:\n", " X = X.to(device)\n", " y = y.to(device)\n", " y_hat = net(X)\n", " l = loss(y_hat, y)\n", " optimizer.zero_grad()\n", " l.backward()\n", " optimizer.step()\n", " train_l_sum += l.cpu().item()\n", " train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()\n", " n += y.shape[0]\n", " batch_count += 1\n", " test_acc = evaluate_accuracy(test_iter, net)\n", " print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'\n", " % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2019-05-29T13:58:00.333237Z", "start_time": "2019-05-29T13:57:38.456012Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training on cuda\n", "epoch 1, loss 1.7885, train acc 0.337, test acc 0.584, time 2.4 sec\n", "epoch 2, loss 0.4793, train acc 0.614, test acc 0.666, time 2.3 sec\n", "epoch 3, loss 0.2637, train acc 0.704, test acc 0.720, time 2.3 sec\n", "epoch 4, loss 0.1747, train acc 0.734, test acc 0.740, time 2.2 sec\n", "epoch 5, loss 0.1282, train acc 0.751, test acc 0.749, time 2.2 sec\n" ] } ], "source": [ "lr, num_epochs = 0.001, 5\n", "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n", "train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }