提交 30be2502 编写于 作者: M Macrobull

add UNet and YoloV2 samples

上级 63ac4c2c
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 11:19:45 2019
@author: Macrobull
Not all ops in this file are supported by both Pytorch and ONNX
This only demostrates the conversion/validation workflow from Pytorch to ONNX to Paddle fluid
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from onnx2fluid.torch_export_helper import export_onnx_with_validation
# from https://github.com/milesial/Pytorch-UNet
class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
def forward(self, x):
x = self.conv(x)
return x
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch)
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(nn.MaxPool2d(2), double_conv(in_ch, out_ch))
def forward(self, x):
x = self.mpconv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(
scale_factor=2, mode='bilinear') #, align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
self.conv = double_conv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
if hasattr(self, 'diffY'):
diffY = self.diffY
diffX = self.diffX
else:
diffY = self.diffY = x2.size()[2] - x1.size()[2]
diffX = self.diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(
x1,
(diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
# for padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)
model = UNet(3, 80)
model.eval()
xb = torch.rand((1, 3, 512, 512))
yp = model(xb)
export_onnx_with_validation(
model, (xb, ),
'sample_unet', ['image'], ['pred'],
verbose=True,
training=False)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 11:19:45 2019
@author: Macrobull
Not all ops in this file are supported by both Pytorch and ONNX
This only demostrates the conversion/validation workflow from Pytorch to ONNX to Paddle fluid
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from onnx2fluid.torch_export_helper import export_onnx_with_validation
# from https://github.com/santoshgsk/yolov2-pytorch/blob/master/yolotorch.ipynb
class Yolov2(nn.Module):
def __init__(self):
super(Yolov2, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm2 = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(
in_channels=64,
out_channels=128,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm3 = nn.BatchNorm2d(128)
self.conv4 = nn.Conv2d(
in_channels=128,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.batchnorm4 = nn.BatchNorm2d(64)
self.conv5 = nn.Conv2d(
in_channels=64,
out_channels=128,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm5 = nn.BatchNorm2d(128)
self.conv6 = nn.Conv2d(
in_channels=128,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm6 = nn.BatchNorm2d(256)
self.conv7 = nn.Conv2d(
in_channels=256,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.batchnorm7 = nn.BatchNorm2d(128)
self.conv8 = nn.Conv2d(
in_channels=128,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm8 = nn.BatchNorm2d(256)
self.conv9 = nn.Conv2d(
in_channels=256,
out_channels=512,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm9 = nn.BatchNorm2d(512)
self.conv10 = nn.Conv2d(
in_channels=512,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.batchnorm10 = nn.BatchNorm2d(256)
self.conv11 = nn.Conv2d(
in_channels=256,
out_channels=512,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm11 = nn.BatchNorm2d(512)
self.conv12 = nn.Conv2d(
in_channels=512,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.batchnorm12 = nn.BatchNorm2d(256)
self.conv13 = nn.Conv2d(
in_channels=256,
out_channels=512,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm13 = nn.BatchNorm2d(512)
self.conv14 = nn.Conv2d(
in_channels=512,
out_channels=1024,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm14 = nn.BatchNorm2d(1024)
self.conv15 = nn.Conv2d(
in_channels=1024,
out_channels=512,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.batchnorm15 = nn.BatchNorm2d(512)
self.conv16 = nn.Conv2d(
in_channels=512,
out_channels=1024,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm16 = nn.BatchNorm2d(1024)
self.conv17 = nn.Conv2d(
in_channels=1024,
out_channels=512,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.batchnorm17 = nn.BatchNorm2d(512)
self.conv18 = nn.Conv2d(
in_channels=512,
out_channels=1024,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm18 = nn.BatchNorm2d(1024)
self.conv19 = nn.Conv2d(
in_channels=1024,
out_channels=1024,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm19 = nn.BatchNorm2d(1024)
self.conv20 = nn.Conv2d(
in_channels=1024,
out_channels=1024,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm20 = nn.BatchNorm2d(1024)
self.conv21 = nn.Conv2d(
in_channels=3072,
out_channels=1024,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.batchnorm21 = nn.BatchNorm2d(1024)
self.conv22 = nn.Conv2d(
in_channels=1024,
out_channels=125,
kernel_size=1,
stride=1,
padding=0)
def reorg_layer(self, x):
stride = 2
if hasattr(self, 'batch_size'):
batch_size, channels, height, width = self.batch_size, self.channels, self.height, self.width
new_ht = self.new_ht
new_wd = self.new_wd
new_channels = self.new_channels
else:
batch_size, channels, height, width = self.batch_size, self.channels, self.height, self.width = x.size(
)
new_ht = self.new_ht = height // stride
new_wd = self.new_wd = width // stride
new_channels = self.new_channels = channels * stride * stride
passthrough = x.permute(0, 2, 3, 1)
passthrough = passthrough.contiguous().view(-1, new_ht, stride, new_wd,
stride, channels)
passthrough = passthrough.permute(0, 1, 3, 2, 4, 5)
passthrough = passthrough.contiguous().view(-1, new_ht, new_wd,
new_channels)
passthrough = passthrough.permute(0, 3, 1, 2)
return passthrough
def forward(self, x):
out = F.max_pool2d(
F.leaky_relu(self.batchnorm1(self.conv1(x)), negative_slope=0.1),
2,
stride=2)
out = F.max_pool2d(
F.leaky_relu(self.batchnorm2(self.conv2(out)), negative_slope=0.1),
2,
stride=2)
out = F.leaky_relu(self.batchnorm3(self.conv3(out)), negative_slope=0.1)
out = F.leaky_relu(self.batchnorm4(self.conv4(out)), negative_slope=0.1)
out = F.leaky_relu(self.batchnorm5(self.conv5(out)), negative_slope=0.1)
out = F.max_pool2d(out, 2, stride=2)
out = F.leaky_relu(self.batchnorm6(self.conv6(out)), negative_slope=0.1)
out = F.leaky_relu(self.batchnorm7(self.conv7(out)), negative_slope=0.1)
out = F.leaky_relu(self.batchnorm8(self.conv8(out)), negative_slope=0.1)
out = F.max_pool2d(out, 2, stride=2)
out = F.leaky_relu(self.batchnorm9(self.conv9(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm10(self.conv10(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm11(self.conv11(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm12(self.conv12(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm13(self.conv13(out)), negative_slope=0.1)
passthrough = self.reorg_layer(out)
out = F.max_pool2d(out, 2, stride=2)
out = F.leaky_relu(
self.batchnorm14(self.conv14(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm15(self.conv15(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm16(self.conv16(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm17(self.conv17(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm18(self.conv18(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm19(self.conv19(out)), negative_slope=0.1)
out = F.leaky_relu(
self.batchnorm20(self.conv20(out)), negative_slope=0.1)
out = torch.cat([passthrough, out], 1)
out = F.leaky_relu(
self.batchnorm21(self.conv21(out)), negative_slope=0.1)
out = self.conv22(out)
return out
model = Yolov2()
model.eval()
xb = torch.rand((1, 3, 224, 224))
yp = model(xb)
export_onnx_with_validation(
model, (xb, ),
'sample_yolov2', ['image'], ['pred'],
verbose=True,
training=False)
......@@ -77,7 +77,7 @@ DEFAULT_OP_MAPPING = {
'Sqrt': ['sqrt', ['X'], ['Out']],
'Tanh': ['tanh', ['X'], ['Out']],
'ThresholdedRelu': ['thresholded_relu', ['X'], ['Out'], dict(alpha='threshold')],
'Transpose': ['transpose', ['X'], ['Out']], # FIXME: emit transpose2
# 'Transpose': ['transpose', ['X'], ['Out']],
'Unsqueeze': ['unsqueeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit unsqueeze2
## binary ops ##
'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
......@@ -1779,6 +1779,45 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
)
def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
onnx::Transpose-1:
"""
# I/O
val_data, = inputs
val_transposed, = outputs
var_data = _make_var_name(val_data)
var_transposed = _make_var_name(val_transposed)
# interpretation
fluid_op = 'transpose'
perm = attrs['perm'] # required
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
prog.Code('{} = layers.{}({}'
', perm={}'
'{})'.format(
var_transposed,
fluid_op,
var_data,
# attrs
perm,
name_attr,
))
fluid_op = 'transpose2'
var_xshape = name + '.xshape' # dummy output
prog.VarDesc(var_xshape)
prog.VarDesc(var_transposed)
prog.OpDesc(
fluid_op,
([var_data], 'X'),
([var_transposed, var_xshape], 'Out', 'XShape'),
dict(axis=perm), # f**k you API
)
def Upsample(prog,
inputs,
outputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册