Commit b1786496 authored by W wanghaoshuang

Add demo of DML for pytorch

Parent c1d7dbf3
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. 安装依赖\n",
"\n",
"### 1.1 安装PaddleSlim\n",
"\n",
"```\n",
"git clone https://github.com/PaddlePaddle/PaddleSlim.git\n",
"cd PaddleSlim\n",
"python setup.py install\n",
"```\n",
"\n",
"### 1.2 安装pytorch\n",
"\n",
"```\n",
"pip install torch torchvision\n",
"```\n",
"\n",
"## 2. Import依赖与环境设置"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import argparse\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torchvision import datasets, transforms, models\n",
"from torch.optim.lr_scheduler import StepLR\n",
"from paddleslim.dist import DML\n",
"\n",
"args = {\"batch-size\": 256,\n",
" \"test-batch-size\": 256,\n",
" \"epochs\": 10,\n",
" \"lr\": 1.0,\n",
" \"gamma\": 0.7,\n",
" \"seed\": 1,\n",
" \"log-interval\": 10}\n",
"\n",
"\n",
"\n",
"use_cuda = torch.cuda.is_available()\n",
"torch.manual_seed(args[\"seed\"])\n",
"device = torch.device(\"cuda\" if use_cuda else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 准备数据\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"ename": "ImportError",
"evalue": "FloatProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html",
"output_type": "error",
"traceback": [
"\u001b[0;31m\u001b[0m",
"\u001b[0;31mImportError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-3-1641ec60d682>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m transform=transforms.Compose([\n\u001b[1;32m 7\u001b[0m \u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mToTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormalize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m0.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m ])),\n\u001b[1;32m 10\u001b[0m batch_size=args[\"batch_size\"], shuffle=True, **kwargs)\n",
"\u001b[0;32m/root/envs/paddle_1.8/lib/python2.7/site-packages/torchvision/datasets/cifar.pyc\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, train, transform, target_transform, download)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 65\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_integrity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/root/envs/paddle_1.8/lib/python2.7/site-packages/torchvision/datasets/cifar.pyc\u001b[0m in \u001b[0;36mdownload\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Files already downloaded and verified'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 148\u001b[0;31m \u001b[0mdownload_and_extract_archive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtgz_md5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/root/envs/paddle_1.8/lib/python2.7/site-packages/torchvision/datasets/utils.pyc\u001b[0m in \u001b[0;36mdownload_and_extract_archive\u001b[0;34m(url, download_root, extract_root, filename, md5, remove_finished)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbasename\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 263\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 264\u001b[0;31m \u001b[0mdownload_url\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 265\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 266\u001b[0m \u001b[0marchive\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/root/envs/paddle_1.8/lib/python2.7/site-packages/torchvision/datasets/utils.pyc\u001b[0m in \u001b[0;36mdownload_url\u001b[0;34m(url, root, filename, md5)\u001b[0m\n\u001b[1;32m 83\u001b[0m urllib.request.urlretrieve(\n\u001b[1;32m 84\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0mreporthook\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgen_bar_updater\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m )\n\u001b[1;32m 87\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mURLError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIOError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/root/envs/paddle_1.8/lib/python2.7/site-packages/torchvision/datasets/utils.pyc\u001b[0m in \u001b[0;36mgen_bar_updater\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgen_bar_updater\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mpbar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbar_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mblock_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtotal_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/root/envs/paddle_1.8/lib/python2.7/site-packages/tqdm/notebook.pyc\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0mtotal\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtotal\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0munit_scale\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtotal\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtotal\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 208\u001b[0m self.container = self.status_printer(\n\u001b[0;32m--> 209\u001b[0;31m self.fp, total, self.desc, self.ncols)\n\u001b[0m\u001b[1;32m 210\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisplay\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/root/envs/paddle_1.8/lib/python2.7/site-packages/tqdm/notebook.pyc\u001b[0m in \u001b[0;36mstatus_printer\u001b[0;34m(_, total, desc, ncols)\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;31m# #187 #451 #558\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 103\u001b[0m raise ImportError(\n\u001b[0;32m--> 104\u001b[0;31m \u001b[0;34m\"FloatProgress not found. Please update jupyter and ipywidgets.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 105\u001b[0m \u001b[0;34m\" See https://ipywidgets.readthedocs.io/en/stable\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 106\u001b[0m \"/user_install.html\")\n",
"\u001b[0;31mImportError\u001b[0m: FloatProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html"
]
}
],
"source": [
"\n",
"\n",
"kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}\n",
"train_loader = torch.utils.data.DataLoader(\n",
" datasets.CIFAR10('../data', train=True, download=True,\n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
" ])),\n",
" batch_size=args[\"batch_size\"], shuffle=True, **kwargs)\n",
"test_loader = torch.utils.data.DataLoader(\n",
" datasets.CIFAR10('../data', train=False, transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
" ])),\n",
" batch_size=args[\"test_batch_size\"], shuffle=True, **kwargs)\n",
"\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. 定义模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = models.mobilenet_v2(num_classes=10).to(device)\n",
"optimizer = optim.Adadelta(model.parameters(), lr=args.lr)\n",
"scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. 添加DML修饰\n",
"### 5.1 将模型转为DML模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = DML(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.2 将优化器转为DML优化器"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimizer = model.opt(optimizer)\n",
"scheduler = model.lr(scheduler)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6. 定义训练方法\n",
"\n",
"将原来的交叉熵损失替换为DML损失,代码如下:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(args, model, device, train_loader, optimizer, epoch):\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(train_loader):\n",
" data, target = data.to(device), target.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = model.dml_loss(output, target) \n",
"# output = F.softmax(output, dim=1)\n",
"# loss = F.cross_entropy(output, target)\n",
"# loss.backward()\n",
" optimizer.step()\n",
" if batch_idx % args[\"log_interval\"] == 0:\n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch, batch_idx * len(data), len(train_loader.dataset),\n",
" 100. * batch_idx / len(train_loader), loss.item()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. 定义测试方法"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def test(model, device, test_loader):\n",
" model.eval()\n",
" test_loss = 0\n",
" correct = 0\n",
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" data, target = data.to(device), target.to(device)\n",
" output = model(data)\n",
" output = F.softmax(output, dim=1)\n",
" loss = F.cross_entropy(output, target, reduction=\"sum\")\n",
" test_loss += loss\n",
" pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
"\n",
" test_loss /= len(test_loader.dataset)\n",
"\n",
" print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
" test_loss, correct, len(test_loader.dataset),\n",
" 100. * correct / len(test_loader.dataset))) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. 开始训练"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"epochs = 10\n",
"for epoch in range(1, epochs + 1):\n",
" train(args, model, device, train_loader, optimizer, epoch)\n",
" test(model, device, test_loader)\n",
" scheduler.step()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR
from paddleslim.dist import DML
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
        loss = model.dml_loss(output, target)  # DML loss replaces the plain cross-entropy below
# output = F.softmax(output, dim=1)
# loss = F.cross_entropy(output, target)
# loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
            # cross_entropy applies log_softmax internally, so it takes raw logits
            loss = F.cross_entropy(output, target, reduction="sum")
            test_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
def main():
# Training settings
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 DML Example')
parser.add_argument(
'--batch-size',
type=int,
default=256,
metavar='N',
        help='input batch size for training (default: 256)')
parser.add_argument(
'--test-batch-size',
type=int,
default=1000,
metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument(
'--epochs',
type=int,
default=14,
metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument(
'--lr',
type=float,
default=1.0,
metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument(
'--gamma',
type=float,
default=0.7,
metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument(
'--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
parser.add_argument(
'--seed',
type=int,
default=1,
metavar='S',
help='random seed (default: 1)')
parser.add_argument(
'--log-interval',
type=int,
default=10,
metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument(
'--save-model',
action='store_true',
default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10(
'../data',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))
])),
batch_size=args.batch_size,
shuffle=True,
**kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10(
'../data',
train=False,
transform=transforms.Compose([
transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))
])),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs)
model = models.mobilenet_v2(num_classes=10).to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    model = DML(model)  # wrap the model with a freshly initialized DML peer
    optimizer = model.opt(optimizer)  # one optimizer clone per peer model
scheduler = model.lr(scheduler)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step()
if args.save_model:
        torch.save(model.state_dict(), "cifar10_dml.pt")
if __name__ == '__main__':
main()
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import paddle.fluid as fluid
class PD_DML(fluid.dygraph.Layer):
def __init__(self, model, use_parallel):
super(PD_DML, self).__init__()
self.model = model
self.use_parallel = use_parallel
self.model_num = len(self.model)
if self.use_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
self.model = [
fluid.dygraph.parallel.DataParallel(m, strategy)
for m in self.model
]
def full_name(self):
return [m.full_name() for m in self.model]
def forward(self, input):
return [m(input) for m in self.model]
def opt(self, optimizer):
assert len(
optimizer
) == self.model_num, "The number of optimizers must match the number of models"
optimizer = DMLOptimizers(self.model, optimizer, self.use_parallel)
return optimizer
def ce_loss(self, logits, labels):
assert len(
logits
) == self.model_num, "The number of logits must match the number of models"
ce_losses = []
for i in range(self.model_num):
ce_losses.append(
fluid.layers.mean(
fluid.layers.softmax_with_cross_entropy(logits[i],
labels)))
return ce_losses
def kl_loss(self, logits):
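        # Mutual distillation: model i is pulled toward each peer j through
        # KL(p_j || p_i) on the softmax outputs, averaged over the K-1 peers.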
assert len(
logits
) == self.model_num, "The number of logits must match the number of models"
if self.model_num == 1:
return []
kl_losses = []
for i in range(self.model_num):
cur_kl_loss = 0
for j in range(self.model_num):
if i != j:
x = fluid.layers.log_softmax(logits[i], axis=1)
y = fluid.layers.softmax(logits[j], axis=1)
cur_kl_loss += fluid.layers.kldiv_loss(
x, y, reduction='batchmean')
kl_losses.append(cur_kl_loss / (self.model_num - 1))
return kl_losses
def loss(self, logits, labels):
gt_losses = self.ce_loss(logits, labels)
kl_losses = self.kl_loss(logits)
if self.model_num > 1:
return [a + b for a, b in zip(gt_losses, kl_losses)]
else:
return gt_losses
def acc(self, logits, labels, k):
accs = [
fluid.layers.accuracy(
input=l, label=labels, k=k) for l in logits
]
return accs
def train(self):
for m in self.model:
m.train()
def eval(self):
for m in self.model:
m.eval()
class DMLOptimizers(object):
def __init__(self, model, optimizer, use_parallel):
self.model = model
self.optimizer = optimizer
self.use_parallel = use_parallel
def minimize(self, losses):
assert len(losses) == len(
self.optimizer
), "The number of losses must match the number of optimizers"
for i in range(len(losses)):
if self.use_parallel:
losses[i] = self.model[i].scale_loss(losses[i])
losses[i].backward()
self.model[i].apply_collective_grads()
else:
losses[i].backward()
self.optimizer[i].minimize(losses[i])
self.model[i].clear_gradients()
def get_lr(self):
current_step_lr = [opt.current_step_lr() for opt in self.optimizer]
return current_step_lr
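
For reference, here is a minimal sketch of driving `PD_DML` in Paddle dygraph mode. The `MobileNetV2` constructor and the `reader` data source are illustrative placeholders, not part of this commit; only the `PD_DML`/`DMLOptimizers` calls come from the class above.

```python
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable

with fluid.dygraph.guard():
    # Two peer models; PD_DML expects a list of models.
    peers = [MobileNetV2(class_dim=10), MobileNetV2(class_dim=10)]
    dml = PD_DML(peers, use_parallel=False)
    # One optimizer per peer, in the same order as the models.
    opt = dml.opt([
        fluid.optimizer.MomentumOptimizer(
            learning_rate=0.1, momentum=0.9, parameter_list=m.parameters())
        for m in dml.model
    ])
    for images, labels in reader():  # placeholder data source
        logits = dml(to_variable(images))               # one logit set per peer
        losses = dml.loss(logits, to_variable(labels))  # CE + mutual KL per peer
        opt.minimize(losses)                            # backward + step per peer
```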
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import weakref
from functools import wraps
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.optim import Optimizer
class TORCH_DML(nn.Module):
def __init__(self, model):
super(TORCH_DML, self).__init__()
        self.model = model
        self.partners = [model]
        device = next(self.model.parameters()).device
        # Create a second, independently initialized peer of the same
        # architecture; num_classes=10 is hardcoded for this CIFAR10 demo.
        self.partners.append(type(self.model)(num_classes=10).to(device))
        # self.partners.append(copy.deepcopy(self.model).to(device))
    def forward(self, input):
        if self.model.training:
            # Training: return one set of logits per peer for the DML loss.
            return [m(input) for m in self.partners]
        else:
            # Evaluation uses only the primary model.
            return self.model(input)
    def opt(self, optimizer):
        optimizers = []
        for partner in self.partners:
            # Clone the user's optimizer and rebind it to this peer's parameters.
            new_opt = copy.deepcopy(optimizer)
            new_opt.param_groups = []
            new_opt.add_param_group({"params": partner.parameters()})
            optimizers.append(new_opt)
        self.optimizer = DMLOptimizer(optimizers, self)
        return self.optimizer
def _clone_scheduler(self, scheduler, optimizer):
scheduler = copy.deepcopy(scheduler)
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError('{} is not an Optimizer'.format(
type(optimizer).__name__))
scheduler.optimizer = optimizer
scheduler.last_epoch -= 1
# Initialize epoch and base learning rates
if scheduler.last_epoch == -1:
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
else:
for i, group in enumerate(optimizer.param_groups):
if 'initial_lr' not in group:
raise KeyError(
"param 'initial_lr' is not specified "
"in param_groups[{}] when resuming an optimizer".
format(i))
scheduler.base_lrs = list(
map(lambda group: group['initial_lr'], optimizer.param_groups))
# Following https://github.com/pytorch/pytorch/issues/20124
# We would like to ensure that `lr_scheduler.step()` is called after
# `optimizer.step()`
def with_counter(method):
if getattr(method, '_with_counter', False):
# `optimizer.step()` has already been replaced, return.
return method
# Keep a weak reference to the optimizer instance to prevent
# cyclic references.
instance_ref = weakref.ref(method.__self__)
# Get the unbound method for the same purpose.
func = method.__func__
cls = instance_ref().__class__
del method
@wraps(func)
def wrapper(*args, **kwargs):
instance = instance_ref()
instance._step_count += 1
wrapped = func.__get__(instance, cls)
return wrapped(*args, **kwargs)
# Note that the returned function here is no longer a bound method,
# so attributes like `__func__` and `__self__` no longer exist.
wrapper._with_counter = True
return wrapper
scheduler.optimizer.step = with_counter(scheduler.optimizer.step)
scheduler.optimizer._step_count = 0
scheduler._step_count = 0
scheduler.step()
return scheduler
def lr(self, scheduler):
schedulers = []
for opt in self.optimizer.optimizers:
new_scheduler = self._clone_scheduler(scheduler, opt)
schedulers.append(new_scheduler)
self.scheduler = DMLScheduler(schedulers, self)
return self.scheduler
    def nll_loss(self, logit, label):
        # cross_entropy applies log_softmax internally, so it takes raw logits
        # (matching the Paddle version's softmax_with_cross_entropy).
        return F.cross_entropy(logit, label)
def kl_loss(self, logit0, logit1):
logit0 = F.log_softmax(logit0, dim=1)
logit1 = F.softmax(logit1, dim=1)
return F.kl_div(logit0, logit1, reduction='batchmean')
def dml_loss(self, logits, label, gt_loss_func=None, dist_loss_func=None):
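        # Per-model loss: L_i = gt_loss(logits_i, label) + sum_{j != i} KL(p_j || p_i).
        # All losses are cached in self.losses so that DMLOptimizer.step() can
        # back-propagate each of them; only losses[0] is returned for logging.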
gt_loss_func = gt_loss_func if gt_loss_func is not None else self.nll_loss
dist_loss_func = dist_loss_func if dist_loss_func is not None else self.kl_loss
self.losses = []
for i in range(len(logits)):
logit = logits[i]
cur_loss = gt_loss_func(logit, label)
for j in range(len(logits)):
if i != j:
dist_loss = dist_loss_func(logits[i], logits[j])
cur_loss += dist_loss
self.losses.append(cur_loss)
return self.losses[0]
    def train(self):
        for m in self.partners:
            m.train()

    def eval(self):
        for m in self.partners:
            m.eval()
class DMLOptimizer(object):
def __init__(self, optimizers, model):
self.optimizers = optimizers
self.model = model
def step(self):
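        # Each cached loss contains every peer's logits, so the shared forward
        # graph must be retained across the successive backward() calls.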
for loss, optimizer in zip(self.model.losses, self.optimizers):
loss.backward(retain_graph=True)
optimizer.step()
def zero_grad(self):
for optimizer in self.optimizers:
optimizer.zero_grad()
class DMLScheduler(object):
def __init__(self, schedulers, model):
self.schedulers = schedulers
self.model = model
def step(self):
for scheduler in self.schedulers:
scheduler.step()
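
`dml_loss` accepts custom `gt_loss_func` and `dist_loss_func` callables, which the demo above never exercises. Here is a minimal sketch of swapping in a temperature-scaled KL term; the helper `kl_with_temperature` and the temperature value are illustrative, not part of this commit:

```python
import torch.nn.functional as F

def kl_with_temperature(logit_i, logit_j, T=4.0):
    # KL(p_j^T || p_i^T) on temperature-softened distributions, scaled by
    # T*T so gradient magnitudes stay comparable to the T=1 case.
    log_p_i = F.log_softmax(logit_i / T, dim=1)
    p_j = F.softmax(logit_j / T, dim=1)
    return F.kl_div(log_p_i, p_j, reduction='batchmean') * (T * T)

# In the training loop, in place of the default call:
# loss = model.dml_loss(output, target,
#                       gt_loss_func=F.cross_entropy,
#                       dist_loss_func=kl_with_temperature)
```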