Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
VisualDL
提交
2ff30dd9
V
VisualDL
项目概览
PaddlePaddle
/
VisualDL
1 年多 前同步成功
通知
88
Star
4655
Fork
642
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
2
Wiki
5
Wiki
分析
仓库
DevOps
项目成员
Pages
V
VisualDL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
2
合并请求
2
Pages
分析
分析
仓库分析
DevOps
Wiki
5
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2ff30dd9
编写于
1月 31, 2018
作者:
D
daminglu
提交者:
GitHub
1月 31, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Demo of Using VisualDL in PyTorch (#230)
上级
a601714e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
630 addition
and
0 deletion
+630
-0
demo/pytorch/TUTORIAL_CN.md
demo/pytorch/TUTORIAL_CN.md
+211
-0
demo/pytorch/pytorch_cifar10.ipynb
demo/pytorch/pytorch_cifar10.ipynb
+271
-0
demo/pytorch/pytorch_cifar10.py
demo/pytorch/pytorch_cifar10.py
+148
-0
未找到文件。
demo/pytorch/TUTORIAL_CN.md
0 → 100644
浏览文件 @
2ff30dd9
# 如何在PyTorch中使用VisualDL
下面我们演示一下如何在PyTorch中使用VisualDL,从而可以把PyTorch的训练过程以及最后的模型可视化出来。我们将以PyTorch用卷积神经网络(CNN, Convolutional Neural Network)来训练
[
Cifar10
](
https://www.cs.toronto.edu/~kriz/cifar.html
)
数据集作为例子。
程序的主体来自PyTorch的
[
Tutorial
](
http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
)
我们同时提供了 Jupyter Notebook 的可交互版本。请参见本文件夹里面的 pytorch_cifar10.ipynb
```
python
import
torch
import
torchvision
import
torchvision.transforms
as
transforms
from
torch.autograd
import
Variable
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
matplotlib
matplotlib
.
use
(
'Agg'
)
from
visualdl
import
LogWriter
transform
=
transforms
.
Compose
(
[
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.5
,
0.5
,
0.5
),
(
0.5
,
0.5
,
0.5
))])
trainset
=
torchvision
.
datasets
.
CIFAR10
(
root
=
'./data'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
trainloader
=
torch
.
utils
.
data
.
DataLoader
(
trainset
,
batch_size
=
500
,
shuffle
=
True
,
num_workers
=
2
)
testset
=
torchvision
.
datasets
.
CIFAR10
(
root
=
'./data'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
testloader
=
torch
.
utils
.
data
.
DataLoader
(
testset
,
batch_size
=
500
,
shuffle
=
False
,
num_workers
=
2
)
classes
=
(
'plane'
,
'car'
,
'bird'
,
'cat'
,
'deer'
,
'dog'
,
'frog'
,
'horse'
,
'ship'
,
'truck'
)
import
matplotlib.pyplot
as
plt
import
numpy
as
np
# functions to show an image
def
imshow
(
img
):
img
=
img
/
2
+
0.5
# unnormalize
npimg
=
img
.
numpy
()
fig
,
ax
=
plt
.
subplots
()
plt
.
imshow
(
np
.
transpose
(
npimg
,
(
1
,
2
,
0
)))
# we can either show the image or save it locally
# plt.show()
fig
.
savefig
(
'out'
+
str
(
np
.
random
.
randint
(
0
,
10000
))
+
'.pdf'
)
```
我们可以预览一下将要分析的 Cifar10 图片集:
<p
align=
center
>
<img
width=
"70%"
src=
"https://github.com/daming-lu/large_files/blob/master/pytorch_demo_figs/pytorch_cifar10_show_image.png?raw=true"
/>
</p>
然后我们开始创建 VisualDL 的数据采集 loggers
```
python
logdir
=
"/workspace"
logger
=
LogWriter
(
logdir
,
sync_cycle
=
100
)
# mark the components with 'train' label.
with
logger
.
mode
(
"train"
):
# create a scalar component called 'scalars/'
scalar_pytorch_train_loss
=
logger
.
scalar
(
"scalars/scalar_pytorch_train_loss"
)
image1
=
logger
.
image
(
"images/image1"
,
1
)
image2
=
logger
.
image
(
"images/image2"
,
1
)
histogram0
=
logger
.
histogram
(
"histogram/histogram0"
,
num_buckets
=
100
)
```
Cifar10 中有 50000 个训练图像和 10000 个测试图像。我们每 500 个作为一个训练集,图片采样也选 500 。 每个训练集 (batch) 是如下的维度:
500 x 3 x 32 x 32
接下来我们开始创建 CNN 模型
```
python
# get some random training images
dataiter
=
iter
(
trainloader
)
images
,
labels
=
dataiter
.
next
()
# show images
imshow
(
torchvision
.
utils
.
make_grid
(
images
))
# print labels
print
(
' '
.
join
(
'%5s'
%
classes
[
labels
[
j
]]
for
j
in
range
(
4
)))
# Define a Convolution Neural Network
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
3
,
6
,
5
)
self
.
pool
=
nn
.
MaxPool2d
(
2
,
2
)
self
.
conv2
=
nn
.
Conv2d
(
6
,
16
,
5
)
self
.
fc1
=
nn
.
Linear
(
16
*
5
*
5
,
120
)
self
.
fc2
=
nn
.
Linear
(
120
,
84
)
self
.
fc3
=
nn
.
Linear
(
84
,
10
)
def
forward
(
self
,
x
):
x
=
self
.
pool
(
F
.
relu
(
self
.
conv1
(
x
)))
x
=
self
.
pool
(
F
.
relu
(
self
.
conv2
(
x
)))
x
=
x
.
view
(
-
1
,
16
*
5
*
5
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
F
.
relu
(
self
.
fc2
(
x
))
x
=
self
.
fc3
(
x
)
return
x
net
=
Net
()
```
接下来我们开始训练并且同时用 VisualDL 来采集相关数据
```
python
# Train the network
for
epoch
in
range
(
5
):
# loop over the dataset multiple times
running_loss
=
0.0
for
i
,
data
in
enumerate
(
trainloader
,
0
):
# get the inputs
inputs
,
labels
=
data
# wrap them in Variable
inputs
,
labels
=
Variable
(
inputs
),
Variable
(
labels
)
# zero the parameter gradients
optimizer
.
zero_grad
()
# forward + backward + optimize
outputs
=
net
(
inputs
)
loss
=
criterion
(
outputs
,
labels
)
loss
.
backward
()
optimizer
.
step
()
# use VisualDL to retrieve metrics
# scalar
scalar_pytorch_train_loss
.
add_record
(
train_step
,
float
(
loss
))
# histogram
weight_list
=
net
.
conv1
.
weight
.
view
(
6
*
3
*
5
*
5
,
-
1
)
histogram0
.
add_record
(
train_step
,
weight_list
)
# image
image1
.
start_sampling
()
image1
.
add_sample
([
96
,
25
],
net
.
conv2
.
weight
.
view
(
16
*
6
*
5
*
5
,
-
1
))
image1
.
finish_sampling
()
image2
.
start_sampling
()
image2
.
add_sample
([
18
,
25
],
net
.
conv1
.
weight
.
view
(
6
*
3
*
5
*
5
,
-
1
))
image2
.
finish_sampling
()
train_step
+=
1
# print statistics
running_loss
+=
loss
.
data
[
0
]
if
i
%
2000
==
1999
:
# print every 2000 mini-batches
print
(
'[%d, %5d] loss: %.3f'
%
(
epoch
+
1
,
i
+
1
,
running_loss
/
2000
))
running_loss
=
0.0
print
(
'Finished Training'
)
```
最后,因为 PyTorch 采用 Dynamic Computation Graphs,我们用一个 dummy 输入来空跑一下模型,以便产生图
```
python
import
torch.onnx
dummy_input
=
Variable
(
torch
.
randn
(
4
,
3
,
32
,
32
))
torch
.
onnx
.
export
(
net
,
dummy_input
,
"pytorch_cifar10.onnx"
)
print
(
'Done'
)
```
训练结束后,各个组件的可视化结果如下:
关于误差的数值图的如下:
<p
align=
center
>
<img
width=
"70%"
src=
"https://github.com/daming-lu/large_files/blob/master/pytorch_demo_figs/sc_scalar.png?raw=true"
/>
</p>
训练过后的第一,第二层卷积权重图的如下:
<p
align=
center
>
<img
width=
"70%"
src=
"https://github.com/daming-lu/large_files/blob/master/pytorch_demo_figs/sc_image.png?raw=true"
/>
</p>
训练参数的柱状图的如下:
<p
align=
center
>
<img
width=
"70%"
src=
"https://github.com/daming-lu/large_files/blob/master/pytorch_demo_figs/sc_hist.png?raw=true"
/>
</p>
模型图的效果如下:
<p
align=
center
>
<img
width=
"70%"
src=
"https://github.com/daming-lu/large_files/blob/master/pytorch_demo_figs/sc_graph.png?raw=true"
/>
</p>
生成的完整效果图可以在
[
这里
](
https://github.com/daming-lu/large_files/blob/master/pytorch_demo_figs/graph.png?raw=true
)
下载。
demo/pytorch/pytorch_cifar10.ipynb
0 → 100644
浏览文件 @
2ff30dd9
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"如何在PyTorch中使用VisualDL\n",
"=====================\n",
"\n",
"下面我们演示一下如何在PyTorch中使用VisualDL,从而可以把PyTorch的训练过程以及最后的模型可视化出来。我们将以PyTorch用卷积神经网络(CNN, Convolutional Neural Network)来训练 [Cifar10](https://www.cs.toronto.edu/~kriz/cifar.html) 数据集作为例子。\n",
"\n",
"程序的主体来自PyTorch的 [Tutorial](http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from torch.autograd import Variable\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"\n",
"import matplotlib\n",
"matplotlib.use('Agg')\n",
"\n",
"from visualdl import LogWriter\n",
"\n",
"\n",
"transform = transforms.Compose(\n",
" [transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
"\n",
"trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
" download=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=500,\n",
" shuffle=True, num_workers=2)\n",
"\n",
"testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
" download=True, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=500,\n",
" shuffle=False, num_workers=2)\n",
"\n",
"classes = ('plane', 'car', 'bird', 'cat',\n",
" 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"\n",
"# functions to show an image\n",
"def imshow(img):\n",
" img = img / 2 + 0.5 # unnormalize\n",
" npimg = img.numpy()\n",
" fig, ax = plt.subplots()\n",
" plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
" # we can either show the image or save it locally\n",
" # plt.show()\n",
" fig.savefig('out' + str(np.random.randint(0, 10000)) + '.pdf')\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"然后我们开始创建 VisualDL 的数据采集 loggers\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"logdir = \"/workspace\"\n",
"logger = LogWriter(logdir, sync_cycle=100)\n",
"\n",
"# mark the components with 'train' label.\n",
"with logger.mode(\"train\"):\n",
" # create a scalar component called 'scalars/'\n",
" scalar_pytorch_train_loss = logger.scalar(\"scalars/scalar_pytorch_train_loss\")\n",
" image1 = logger.image(\"images/image1\", 1)\n",
" image2 = logger.image(\"images/image2\", 1)\n",
" histogram0 = logger.histogram(\"histogram/histogram0\", num_buckets=100)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Cifar10 中有 50000 个训练图像和 10000 个测试图像。我们每 500 个作为一个训练集,图片采样也选 500 。 每个训练集 (batch) 是如下的维度:\n",
"\n",
"500 x 3 x 32 x 32\n",
"\n",
"接下来我们开始创建 CNN 模型\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# get some random training images\n",
"dataiter = iter(trainloader)\n",
"images, labels = dataiter.next()\n",
"\n",
"# show images\n",
"imshow(torchvision.utils.make_grid(images))\n",
"# print labels\n",
"print(' '.join('%5s' % classes[labels[j]] for j in range(4)))\n",
"\n",
"# Define a Convolution Neural Network\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = x.view(-1, 16 * 5 * 5)\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
"\n",
"net = Net()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"接下来我们开始训练并且同时用 VisualDL 来采集相关数据\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Train the network\n",
"for epoch in range(5): # loop over the dataset multiple times\n",
" running_loss = 0.0\n",
" for i, data in enumerate(trainloader, 0):\n",
" # get the inputs\n",
" inputs, labels = data\n",
"\n",
" # wrap them in Variable\n",
" inputs, labels = Variable(inputs), Variable(labels)\n",
"\n",
" # zero the parameter gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # forward + backward + optimize\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # use VisualDL to retrieve metrics\n",
" # scalar\n",
" scalar_pytorch_train_loss.add_record(train_step, float(loss))\n",
"\n",
" # histogram\n",
" weight_list = net.conv1.weight.view(6*3*5*5, -1)\n",
" histogram0.add_record(train_step, weight_list)\n",
"\n",
" # image\n",
" image1.start_sampling()\n",
" image1.add_sample([96, 25], net.conv2.weight.view(16*6*5*5, -1))\n",
" image1.finish_sampling()\n",
"\n",
" image2.start_sampling()\n",
" image2.add_sample([18, 25], net.conv1.weight.view(6*3*5*5, -1))\n",
" image2.finish_sampling()\n",
"\n",
"\n",
" train_step += 1\n",
"\n",
" # print statistics\n",
" running_loss += loss.data[0]\n",
" if i % 2000 == 1999: # print every 2000 mini-batches\n",
" print('[%d, %5d] loss: %.3f' %\n",
" (epoch + 1, i + 1, running_loss / 2000))\n",
" running_loss = 0.0\n",
"\n",
"print('Finished Training')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"最后,因为 PyTorch 采用 Dynamic Computation Graphs,我们用一个 dummy 输入来空跑一下模型,以便产生图"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import torch.onnx\n",
"dummy_input = Variable(torch.randn(4, 3, 32, 32))\n",
"torch.onnx.export(net, dummy_input, \"pytorch_cifar10.onnx\")\n",
"\n",
"print('Done')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.14"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
demo/pytorch/pytorch_cifar10.py
0 → 100644
浏览文件 @
2ff30dd9
import
torch
import
torchvision
import
torchvision.transforms
as
transforms
from
torch.autograd
import
Variable
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torch.onnx
import
matplotlib
from
visualdl
import
LogWriter
import
matplotlib.pyplot
as
plt
import
numpy
as
np
matplotlib
.
use
(
'Agg'
)
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.5
,
0.5
,
0.5
),
(
0.5
,
0.5
,
0.5
))
])
trainset
=
torchvision
.
datasets
.
CIFAR10
(
root
=
'./data'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
trainloader
=
torch
.
utils
.
data
.
DataLoader
(
trainset
,
batch_size
=
500
,
shuffle
=
True
,
num_workers
=
2
)
testset
=
torchvision
.
datasets
.
CIFAR10
(
root
=
'./data'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
testloader
=
torch
.
utils
.
data
.
DataLoader
(
testset
,
batch_size
=
500
,
shuffle
=
False
,
num_workers
=
2
)
classes
=
(
'plane'
,
'car'
,
'bird'
,
'cat'
,
'deer'
,
'dog'
,
'frog'
,
'horse'
,
'ship'
,
'truck'
)
# functions to show an image
def
imshow
(
img
):
img
=
img
/
2
+
0.5
# unnormalize
npimg
=
img
.
numpy
()
fig
,
ax
=
plt
.
subplots
()
plt
.
imshow
(
np
.
transpose
(
npimg
,
(
1
,
2
,
0
)))
# we can either show the image or save it locally
# plt.show()
fig
.
savefig
(
'out'
+
str
(
np
.
random
.
randint
(
0
,
10000
))
+
'.pdf'
)
logdir
=
"/workspace"
logger
=
LogWriter
(
logdir
,
sync_cycle
=
100
)
# mark the components with 'train' label.
with
logger
.
mode
(
"train"
):
# create a scalar component called 'scalars/'
scalar_pytorch_train_loss
=
logger
.
scalar
(
"scalars/scalar_pytorch_train_loss"
)
image1
=
logger
.
image
(
"images/image1"
,
1
)
image2
=
logger
.
image
(
"images/image2"
,
1
)
histogram0
=
logger
.
histogram
(
"histogram/histogram0"
,
num_buckets
=
100
)
# get some random training images
dataiter
=
iter
(
trainloader
)
images
,
labels
=
dataiter
.
next
()
# show images
imshow
(
torchvision
.
utils
.
make_grid
(
images
))
# print labels
print
(
' '
.
join
(
'%5s'
%
classes
[
labels
[
j
]]
for
j
in
range
(
4
)))
# Define a Convolution Neural Network
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
3
,
6
,
5
)
self
.
pool
=
nn
.
MaxPool2d
(
2
,
2
)
self
.
conv2
=
nn
.
Conv2d
(
6
,
16
,
5
)
self
.
fc1
=
nn
.
Linear
(
16
*
5
*
5
,
120
)
self
.
fc2
=
nn
.
Linear
(
120
,
84
)
self
.
fc3
=
nn
.
Linear
(
84
,
10
)
def
forward
(
self
,
x
):
x
=
self
.
pool
(
F
.
relu
(
self
.
conv1
(
x
)))
x
=
self
.
pool
(
F
.
relu
(
self
.
conv2
(
x
)))
x
=
x
.
view
(
-
1
,
16
*
5
*
5
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
F
.
relu
(
self
.
fc2
(
x
))
x
=
self
.
fc3
(
x
)
return
x
net
=
Net
()
# Define a Loss function and optimizer
criterion
=
nn
.
CrossEntropyLoss
()
optimizer
=
optim
.
SGD
(
net
.
parameters
(),
lr
=
0.01
,
momentum
=
0.9
)
train_step
=
0
# Train the network
for
epoch
in
range
(
5
):
# loop over the dataset multiple times
running_loss
=
0.0
for
i
,
data
in
enumerate
(
trainloader
,
0
):
# get the inputs
inputs
,
labels
=
data
# wrap them in Variable
inputs
,
labels
=
Variable
(
inputs
),
Variable
(
labels
)
# zero the parameter gradients
optimizer
.
zero_grad
()
# forward + backward + optimize
outputs
=
net
(
inputs
)
loss
=
criterion
(
outputs
,
labels
)
loss
.
backward
()
optimizer
.
step
()
# use VisualDL to retrieve metrics
# scalar
scalar_pytorch_train_loss
.
add_record
(
train_step
,
float
(
loss
))
# histogram
weight_list
=
net
.
conv1
.
weight
.
view
(
6
*
3
*
5
*
5
,
-
1
)
histogram0
.
add_record
(
train_step
,
weight_list
)
# image
image1
.
start_sampling
()
image1
.
add_sample
([
96
,
25
],
net
.
conv2
.
weight
.
view
(
16
*
6
*
5
*
5
,
-
1
))
image1
.
finish_sampling
()
image2
.
start_sampling
()
image2
.
add_sample
([
18
,
25
],
net
.
conv1
.
weight
.
view
(
6
*
3
*
5
*
5
,
-
1
))
image2
.
finish_sampling
()
train_step
+=
1
# print statistics
running_loss
+=
loss
.
data
[
0
]
if
i
%
2000
==
1999
:
# print every 2000 mini-batches
print
(
'[%d, %5d] loss: %.3f'
%
(
epoch
+
1
,
i
+
1
,
running_loss
/
2000
))
running_loss
=
0.0
print
(
'Finished Training'
)
dummy_input
=
Variable
(
torch
.
randn
(
4
,
3
,
32
,
32
))
torch
.
onnx
.
export
(
net
,
dummy_input
,
"pytorch_cifar10.onnx"
)
print
(
'Done'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录