Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
c224c953
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
接近 2 年 前同步成功
通知
329
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c224c953
编写于
4月 03, 2019
作者:
J
Jason
提交者:
GitHub
4月 03, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #20 from MacroBull/master
Add more samples
上级
6e1eb980
30be2502
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
736 addition
and
142 deletion
+736
-142
onnx2fluid/examples/gen_some_samples.py
onnx2fluid/examples/gen_some_samples.py
+9
-0
onnx2fluid/examples/gen_unet.py
onnx2fluid/examples/gen_unet.py
+142
-0
onnx2fluid/examples/gen_yolov2.py
onnx2fluid/examples/gen_yolov2.py
+297
-0
onnx2fluid/onnx2fluid/conversion.py
onnx2fluid/onnx2fluid/conversion.py
+3
-1
onnx2fluid/onnx2fluid/onnx_utils.py
onnx2fluid/onnx2fluid/onnx_utils.py
+10
-5
onnx2fluid/onnx2fluid/symbolic.py
onnx2fluid/onnx2fluid/symbolic.py
+246
-122
onnx2fluid/onnx2fluid/validation.py
onnx2fluid/onnx2fluid/validation.py
+1
-1
onnx2fluid/onnx2fluid/writer.py
onnx2fluid/onnx2fluid/writer.py
+28
-13
未找到文件。
onnx2fluid/examples/gen_some_samples.py
浏览文件 @
c224c953
...
@@ -35,6 +35,7 @@ idx = 0
...
@@ -35,6 +35,7 @@ idx = 0
#
#
#
#
#model = Model()
#model = Model()
#model.eval()
#xb = torch.rand((2, 3, 4))
#xb = torch.rand((2, 3, 4))
#yp = model(xb)
#yp = model(xb)
#idx += 1
#idx += 1
...
@@ -56,6 +57,7 @@ idx = 0
...
@@ -56,6 +57,7 @@ idx = 0
#
#
#
#
#model = Model()
#model = Model()
#model.eval()
#xb = torch.rand((2, 3))
#xb = torch.rand((2, 3))
#yp = model(xb)
#yp = model(xb)
#idx += 1
#idx += 1
...
@@ -79,6 +81,7 @@ class Model(nn.Module):
...
@@ -79,6 +81,7 @@ class Model(nn.Module):
model
=
Model
()
model
=
Model
()
model
.
eval
()
xb
=
torch
.
rand
((
2
,
3
))
xb
=
torch
.
rand
((
2
,
3
))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
...
@@ -105,6 +108,7 @@ class Model(nn.Module):
...
@@ -105,6 +108,7 @@ class Model(nn.Module):
model
=
Model
()
model
=
Model
()
model
.
eval
()
xb0
=
torch
.
rand
((
2
,
3
))
xb0
=
torch
.
rand
((
2
,
3
))
xb1
=
torch
.
rand
((
2
,
3
))
xb1
=
torch
.
rand
((
2
,
3
))
ya
,
yb
,
yc
=
model
(
xb0
,
xb1
)
ya
,
yb
,
yc
=
model
(
xb0
,
xb1
)
...
@@ -129,6 +133,7 @@ class Model(nn.Module):
...
@@ -129,6 +133,7 @@ class Model(nn.Module):
model
=
Model
()
model
=
Model
()
model
.
eval
()
theta
=
torch
.
rand
((
2
,
2
,
3
))
theta
=
torch
.
rand
((
2
,
2
,
3
))
grid
=
model
(
theta
)
grid
=
model
(
theta
)
idx
+=
1
idx
+=
1
...
@@ -156,6 +161,7 @@ class Model(nn.Module):
...
@@ -156,6 +161,7 @@ class Model(nn.Module):
model
=
Model
()
model
=
Model
()
model
.
eval
()
xb
=
torch
.
rand
((
2
,
3
,
4
,
5
))
xb
=
torch
.
rand
((
2
,
3
,
4
,
5
))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
...
@@ -185,6 +191,7 @@ class Model(nn.Module):
...
@@ -185,6 +191,7 @@ class Model(nn.Module):
model
=
Model
()
model
=
Model
()
model
.
eval
()
xb
=
torch
.
rand
((
2
,
3
,
4
,
5
))
xb
=
torch
.
rand
((
2
,
3
,
4
,
5
))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
...
@@ -209,6 +216,7 @@ export_onnx_with_validation(
...
@@ -209,6 +216,7 @@ export_onnx_with_validation(
#
#
#
#
#model = Model()
#model = Model()
#model.eval()
#xb = torch.rand((2, 3, 4, 5))
#xb = torch.rand((2, 3, 4, 5))
#yp = model(xb)
#yp = model(xb)
#idx += 1
#idx += 1
...
@@ -229,6 +237,7 @@ class Model(nn.Module):
...
@@ -229,6 +237,7 @@ class Model(nn.Module):
model
=
Model
()
model
=
Model
()
model
.
eval
()
xb
=
torch
.
rand
((
2
,
3
))
xb
=
torch
.
rand
((
2
,
3
))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
...
...
onnx2fluid/examples/gen_unet.py
0 → 100644
浏览文件 @
c224c953
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 11:19:45 2019
@author: Macrobull
Not all ops in this file are supported by both Pytorch and ONNX
This only demostrates the conversion/validation workflow from Pytorch to ONNX to Paddle fluid
"""
from
__future__
import
print_function
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
onnx2fluid.torch_export_helper
import
export_onnx_with_validation
# from https://github.com/milesial/Pytorch-UNet
class
double_conv
(
nn
.
Module
):
'''(conv => BN => ReLU) * 2'''
def
__init__
(
self
,
in_ch
,
out_ch
):
super
(
double_conv
,
self
).
__init__
()
self
.
conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_ch
,
out_ch
,
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
out_ch
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Conv2d
(
out_ch
,
out_ch
,
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
out_ch
),
nn
.
ReLU
(
inplace
=
True
))
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
return
x
class
inconv
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
,
out_ch
):
super
(
inconv
,
self
).
__init__
()
self
.
conv
=
double_conv
(
in_ch
,
out_ch
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
return
x
class
down
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
,
out_ch
):
super
(
down
,
self
).
__init__
()
self
.
mpconv
=
nn
.
Sequential
(
nn
.
MaxPool2d
(
2
),
double_conv
(
in_ch
,
out_ch
))
def
forward
(
self
,
x
):
x
=
self
.
mpconv
(
x
)
return
x
class
up
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
,
out_ch
,
bilinear
=
True
):
super
(
up
,
self
).
__init__
()
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if
bilinear
:
self
.
up
=
nn
.
Upsample
(
scale_factor
=
2
,
mode
=
'bilinear'
)
#, align_corners=True)
else
:
self
.
up
=
nn
.
ConvTranspose2d
(
in_ch
//
2
,
in_ch
//
2
,
2
,
stride
=
2
)
self
.
conv
=
double_conv
(
in_ch
,
out_ch
)
def
forward
(
self
,
x1
,
x2
):
x1
=
self
.
up
(
x1
)
# input is CHW
if
hasattr
(
self
,
'diffY'
):
diffY
=
self
.
diffY
diffX
=
self
.
diffX
else
:
diffY
=
self
.
diffY
=
x2
.
size
()[
2
]
-
x1
.
size
()[
2
]
diffX
=
self
.
diffX
=
x2
.
size
()[
3
]
-
x1
.
size
()[
3
]
x1
=
F
.
pad
(
x1
,
(
diffX
//
2
,
diffX
-
diffX
//
2
,
diffY
//
2
,
diffY
-
diffY
//
2
))
# for padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x
=
torch
.
cat
([
x2
,
x1
],
dim
=
1
)
x
=
self
.
conv
(
x
)
return
x
class
outconv
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
,
out_ch
):
super
(
outconv
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
in_ch
,
out_ch
,
1
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
return
x
class
UNet
(
nn
.
Module
):
def
__init__
(
self
,
n_channels
,
n_classes
):
super
(
UNet
,
self
).
__init__
()
self
.
inc
=
inconv
(
n_channels
,
64
)
self
.
down1
=
down
(
64
,
128
)
self
.
down2
=
down
(
128
,
256
)
self
.
down3
=
down
(
256
,
512
)
self
.
down4
=
down
(
512
,
512
)
self
.
up1
=
up
(
1024
,
256
)
self
.
up2
=
up
(
512
,
128
)
self
.
up3
=
up
(
256
,
64
)
self
.
up4
=
up
(
128
,
64
)
self
.
outc
=
outconv
(
64
,
n_classes
)
def
forward
(
self
,
x
):
x1
=
self
.
inc
(
x
)
x2
=
self
.
down1
(
x1
)
x3
=
self
.
down2
(
x2
)
x4
=
self
.
down3
(
x3
)
x5
=
self
.
down4
(
x4
)
x
=
self
.
up1
(
x5
,
x4
)
x
=
self
.
up2
(
x
,
x3
)
x
=
self
.
up3
(
x
,
x2
)
x
=
self
.
up4
(
x
,
x1
)
x
=
self
.
outc
(
x
)
return
F
.
sigmoid
(
x
)
model
=
UNet
(
3
,
80
)
model
.
eval
()
xb
=
torch
.
rand
((
1
,
3
,
512
,
512
))
yp
=
model
(
xb
)
export_onnx_with_validation
(
model
,
(
xb
,
),
'sample_unet'
,
[
'image'
],
[
'pred'
],
verbose
=
True
,
training
=
False
)
onnx2fluid/examples/gen_yolov2.py
0 → 100644
浏览文件 @
c224c953
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 11:19:45 2019
@author: Macrobull
Not all ops in this file are supported by both Pytorch and ONNX
This only demostrates the conversion/validation workflow from Pytorch to ONNX to Paddle fluid
"""
from
__future__
import
print_function
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
onnx2fluid.torch_export_helper
import
export_onnx_with_validation
# from https://github.com/santoshgsk/yolov2-pytorch/blob/master/yolotorch.ipynb
class
Yolov2
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Yolov2
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
in_channels
=
3
,
out_channels
=
32
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm1
=
nn
.
BatchNorm2d
(
32
)
self
.
conv2
=
nn
.
Conv2d
(
in_channels
=
32
,
out_channels
=
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm2
=
nn
.
BatchNorm2d
(
64
)
self
.
conv3
=
nn
.
Conv2d
(
in_channels
=
64
,
out_channels
=
128
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm3
=
nn
.
BatchNorm2d
(
128
)
self
.
conv4
=
nn
.
Conv2d
(
in_channels
=
128
,
out_channels
=
64
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
False
)
self
.
batchnorm4
=
nn
.
BatchNorm2d
(
64
)
self
.
conv5
=
nn
.
Conv2d
(
in_channels
=
64
,
out_channels
=
128
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm5
=
nn
.
BatchNorm2d
(
128
)
self
.
conv6
=
nn
.
Conv2d
(
in_channels
=
128
,
out_channels
=
256
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm6
=
nn
.
BatchNorm2d
(
256
)
self
.
conv7
=
nn
.
Conv2d
(
in_channels
=
256
,
out_channels
=
128
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
False
)
self
.
batchnorm7
=
nn
.
BatchNorm2d
(
128
)
self
.
conv8
=
nn
.
Conv2d
(
in_channels
=
128
,
out_channels
=
256
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm8
=
nn
.
BatchNorm2d
(
256
)
self
.
conv9
=
nn
.
Conv2d
(
in_channels
=
256
,
out_channels
=
512
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm9
=
nn
.
BatchNorm2d
(
512
)
self
.
conv10
=
nn
.
Conv2d
(
in_channels
=
512
,
out_channels
=
256
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
False
)
self
.
batchnorm10
=
nn
.
BatchNorm2d
(
256
)
self
.
conv11
=
nn
.
Conv2d
(
in_channels
=
256
,
out_channels
=
512
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm11
=
nn
.
BatchNorm2d
(
512
)
self
.
conv12
=
nn
.
Conv2d
(
in_channels
=
512
,
out_channels
=
256
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
False
)
self
.
batchnorm12
=
nn
.
BatchNorm2d
(
256
)
self
.
conv13
=
nn
.
Conv2d
(
in_channels
=
256
,
out_channels
=
512
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm13
=
nn
.
BatchNorm2d
(
512
)
self
.
conv14
=
nn
.
Conv2d
(
in_channels
=
512
,
out_channels
=
1024
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm14
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv15
=
nn
.
Conv2d
(
in_channels
=
1024
,
out_channels
=
512
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
False
)
self
.
batchnorm15
=
nn
.
BatchNorm2d
(
512
)
self
.
conv16
=
nn
.
Conv2d
(
in_channels
=
512
,
out_channels
=
1024
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm16
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv17
=
nn
.
Conv2d
(
in_channels
=
1024
,
out_channels
=
512
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
False
)
self
.
batchnorm17
=
nn
.
BatchNorm2d
(
512
)
self
.
conv18
=
nn
.
Conv2d
(
in_channels
=
512
,
out_channels
=
1024
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm18
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv19
=
nn
.
Conv2d
(
in_channels
=
1024
,
out_channels
=
1024
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm19
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv20
=
nn
.
Conv2d
(
in_channels
=
1024
,
out_channels
=
1024
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm20
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv21
=
nn
.
Conv2d
(
in_channels
=
3072
,
out_channels
=
1024
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
batchnorm21
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv22
=
nn
.
Conv2d
(
in_channels
=
1024
,
out_channels
=
125
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
reorg_layer
(
self
,
x
):
stride
=
2
if
hasattr
(
self
,
'batch_size'
):
batch_size
,
channels
,
height
,
width
=
self
.
batch_size
,
self
.
channels
,
self
.
height
,
self
.
width
new_ht
=
self
.
new_ht
new_wd
=
self
.
new_wd
new_channels
=
self
.
new_channels
else
:
batch_size
,
channels
,
height
,
width
=
self
.
batch_size
,
self
.
channels
,
self
.
height
,
self
.
width
=
x
.
size
(
)
new_ht
=
self
.
new_ht
=
height
//
stride
new_wd
=
self
.
new_wd
=
width
//
stride
new_channels
=
self
.
new_channels
=
channels
*
stride
*
stride
passthrough
=
x
.
permute
(
0
,
2
,
3
,
1
)
passthrough
=
passthrough
.
contiguous
().
view
(
-
1
,
new_ht
,
stride
,
new_wd
,
stride
,
channels
)
passthrough
=
passthrough
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
)
passthrough
=
passthrough
.
contiguous
().
view
(
-
1
,
new_ht
,
new_wd
,
new_channels
)
passthrough
=
passthrough
.
permute
(
0
,
3
,
1
,
2
)
return
passthrough
def
forward
(
self
,
x
):
out
=
F
.
max_pool2d
(
F
.
leaky_relu
(
self
.
batchnorm1
(
self
.
conv1
(
x
)),
negative_slope
=
0.1
),
2
,
stride
=
2
)
out
=
F
.
max_pool2d
(
F
.
leaky_relu
(
self
.
batchnorm2
(
self
.
conv2
(
out
)),
negative_slope
=
0.1
),
2
,
stride
=
2
)
out
=
F
.
leaky_relu
(
self
.
batchnorm3
(
self
.
conv3
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm4
(
self
.
conv4
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm5
(
self
.
conv5
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
max_pool2d
(
out
,
2
,
stride
=
2
)
out
=
F
.
leaky_relu
(
self
.
batchnorm6
(
self
.
conv6
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm7
(
self
.
conv7
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm8
(
self
.
conv8
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
max_pool2d
(
out
,
2
,
stride
=
2
)
out
=
F
.
leaky_relu
(
self
.
batchnorm9
(
self
.
conv9
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm10
(
self
.
conv10
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm11
(
self
.
conv11
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm12
(
self
.
conv12
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm13
(
self
.
conv13
(
out
)),
negative_slope
=
0.1
)
passthrough
=
self
.
reorg_layer
(
out
)
out
=
F
.
max_pool2d
(
out
,
2
,
stride
=
2
)
out
=
F
.
leaky_relu
(
self
.
batchnorm14
(
self
.
conv14
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm15
(
self
.
conv15
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm16
(
self
.
conv16
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm17
(
self
.
conv17
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm18
(
self
.
conv18
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm19
(
self
.
conv19
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm20
(
self
.
conv20
(
out
)),
negative_slope
=
0.1
)
out
=
torch
.
cat
([
passthrough
,
out
],
1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm21
(
self
.
conv21
(
out
)),
negative_slope
=
0.1
)
out
=
self
.
conv22
(
out
)
return
out
model
=
Yolov2
()
model
.
eval
()
xb
=
torch
.
rand
((
1
,
3
,
224
,
224
))
yp
=
model
(
xb
)
export_onnx_with_validation
(
model
,
(
xb
,
),
'sample_yolov2'
,
[
'image'
],
[
'pred'
],
verbose
=
True
,
training
=
False
)
onnx2fluid/onnx2fluid/conversion.py
浏览文件 @
c224c953
...
@@ -77,6 +77,7 @@ def convert(onnx_model_filename,
...
@@ -77,6 +77,7 @@ def convert(onnx_model_filename,
logger
.
warning
(
'the ONNX model sanity checking error is suppressed'
)
logger
.
warning
(
'the ONNX model sanity checking error is suppressed'
)
logger
.
warning
(
'value_info inferring may be uncompleted'
)
logger
.
warning
(
'value_info inferring may be uncompleted'
)
# onnx model optimization
# onnx model optimization
logger
.
info
(
'model has %d ops'
,
len
(
onnx_model
.
graph
.
node
))
logger
.
info
(
'optimizing model ...'
)
logger
.
info
(
'optimizing model ...'
)
onnx_model
=
optimize_model_skip_op_for_inference
(
onnx_model
)
onnx_model
=
optimize_model_skip_op_for_inference
(
onnx_model
)
onnx_model
=
optimize_model_strip_initializer
(
onnx_model
)
onnx_model
=
optimize_model_strip_initializer
(
onnx_model
)
...
@@ -142,7 +143,8 @@ def convert(onnx_model_filename,
...
@@ -142,7 +143,8 @@ def convert(onnx_model_filename,
raise
e
raise
e
op_codes
=
fluid_program
.
codes
op_codes
=
fluid_program
.
codes
fluid_program
.
codes
=
[]
fluid_program
.
codes
=
[]
logger
.
info
(
'%d ops converted'
,
len
(
fluid_program
.
op_descs
))
logger
.
info
(
'%d ops in, %d ops out'
,
len
(
onnx_graph
.
node
),
len
(
fluid_program
.
op_descs
))
# weight writer
# weight writer
for
name
,
weight
in
graph_weights
(
onnx_graph
):
for
name
,
weight
in
graph_weights
(
onnx_graph
):
...
...
onnx2fluid/onnx2fluid/onnx_utils.py
浏览文件 @
c224c953
...
@@ -80,7 +80,7 @@ def build_value_refs(nodes):
...
@@ -80,7 +80,7 @@ def build_value_refs(nodes):
def
get_attribute_value2
(
attr
):
def
get_attribute_value2
(
attr
):
"""
"""
get_attribute_value
with tensor conversion
get_attribute_value
enhanced
"""
"""
if
attr
.
type
==
onnx
.
AttributeProto
.
TENSOR
:
if
attr
.
type
==
onnx
.
AttributeProto
.
TENSOR
:
...
@@ -88,6 +88,9 @@ def get_attribute_value2(attr):
...
@@ -88,6 +88,9 @@ def get_attribute_value2(attr):
data
=
attr
.
t
.
raw_data
data
=
attr
.
t
.
raw_data
value
=
np
.
frombuffer
(
value
=
np
.
frombuffer
(
data
,
dtype
=
dtype
,
count
=
(
len
(
data
)
//
dtype
.
itemsize
))
data
,
dtype
=
dtype
,
count
=
(
len
(
data
)
//
dtype
.
itemsize
))
elif
attr
.
type
==
onnx
.
AttributeProto
.
STRING
:
value
=
attr
.
s
value
=
value
.
decode
()
if
isinstance
(
value
,
bytes
)
else
value
else
:
else
:
value
=
get_attribute_value
(
attr
)
value
=
get_attribute_value
(
attr
)
return
value
return
value
...
@@ -127,8 +130,10 @@ def node_topo(nodes, topo='default'):
...
@@ -127,8 +130,10 @@ def node_topo(nodes, topo='default'):
return
list
(
range
(
len
(
nodes
)))
return
list
(
range
(
len
(
nodes
)))
node_topo
=
[]
node_topo
=
[]
node_in_degrees
=
[
len
(
node
.
input
)
for
node
in
nodes
]
node_in_degrees
=
[
len
(
set
(
node
.
input
))
node_out_degrees
=
[
len
(
node
.
output
)
for
node
in
nodes
]
for
node
in
nodes
]
# merge multiple references
node_out_degrees
=
[
len
(
set
(
node
.
output
))
for
node
in
nodes
]
# merge multiple references
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
if
topo
==
'forward'
:
if
topo
==
'forward'
:
...
@@ -395,7 +400,7 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
...
@@ -395,7 +400,7 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
ret_inputs
.
add
().
CopyFrom
(
item
)
ret_inputs
.
add
().
CopyFrom
(
item
)
else
:
else
:
logger
.
debug
(
'input %s(%s%s) stripped'
,
name
,
tensor_dtype
(
item
),
logger
.
debug
(
'input %s(%s%s) stripped'
,
name
,
tensor_dtype
(
item
),
t
ensor_shape
(
item
))
t
uple
(
tensor_shape
(
item
)
))
return
ret
return
ret
...
@@ -422,7 +427,7 @@ def optimize_model_cast(model):
...
@@ -422,7 +427,7 @@ def optimize_model_cast(model):
attrs
=
node_attrs
(
node
)
attrs
=
node_attrs
(
node
)
output_dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
attrs
[
'to'
]]
output_dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
attrs
[
'to'
]]
input_name
=
node
.
input
[
0
]
input_name
=
node
.
input
[
0
]
info
=
value_info
.
get
(
'input_name'
,
None
)
# relax for un-inferrable
info
=
value_info
.
get
(
input_name
,
None
)
# relax for un-inferrable
if
info
is
None
:
if
info
is
None
:
continue
continue
input_dtype
=
info
.
get
(
'dtype'
,
None
)
input_dtype
=
info
.
get
(
'dtype'
,
None
)
...
...
onnx2fluid/onnx2fluid/symbolic.py
浏览文件 @
c224c953
此差异已折叠。
点击以展开。
onnx2fluid/onnx2fluid/validation.py
浏览文件 @
c224c953
...
@@ -54,7 +54,7 @@ def validate(fluid_model_filename,
...
@@ -54,7 +54,7 @@ def validate(fluid_model_filename,
# load model
# load model
fluid_model_dir
,
basename
=
os
.
path
.
split
(
fluid_model_filename
)
fluid_model_dir
,
basename
=
os
.
path
.
split
(
fluid_model_filename
)
if
basename
==
'__model__'
:
# is desc
model
if
basename
==
'__model__'
:
# is desc
program
logger
.
debug
(
'using desc file %s'
,
basename
)
logger
.
debug
(
'using desc file %s'
,
basename
)
prog
,
_
,
var_outs
=
fluid
.
io
.
load_inference_model
(
fluid_model_dir
,
exe
)
prog
,
_
,
var_outs
=
fluid
.
io
.
load_inference_model
(
fluid_model_dir
,
exe
)
out_names
=
var_outs
# HINT: pass var if fetch ops already created
out_names
=
var_outs
# HINT: pass var if fetch ops already created
...
...
onnx2fluid/onnx2fluid/writer.py
浏览文件 @
c224c953
...
@@ -139,19 +139,34 @@ class Program(object):
...
@@ -139,19 +139,34 @@ class Program(object):
elif
isinstance
(
value
,
str
):
elif
isinstance
(
value
,
str
):
od_attr
.
type
=
framework_pb2
.
STRING
od_attr
.
type
=
framework_pb2
.
STRING
od_attr
.
s
=
value
od_attr
.
s
=
value
elif
isinstance
(
value
,
list
)
and
len
(
value
)
>
0
:
elif
isinstance
(
value
,
list
):
if
isinstance
(
value
,
bool
):
# bool.mro() = [bool, int, object]
if
len
(
value
)
>
0
:
od_attr
.
type
=
framework_pb2
.
BOOLEANS
if
isinstance
(
value
,
od_attr
.
bools
.
extend
(
value
)
bool
):
# bool.mro() = [bool, int, object]
elif
isinstance
(
value
[
0
],
int
):
# only cast to int32 list
od_attr
.
type
=
framework_pb2
.
BOOLEANS
od_attr
.
type
=
framework_pb2
.
INTS
od_attr
.
bools
.
extend
(
value
)
od_attr
.
ints
.
extend
(
value
)
elif
isinstance
(
value
[
0
],
int
):
# only cast to int32 list
elif
isinstance
(
value
[
0
],
float
):
od_attr
.
type
=
framework_pb2
.
INTS
od_attr
.
type
=
framework_pb2
.
FLOATS
od_attr
.
ints
.
extend
(
value
)
od_attr
.
floats
.
extend
(
value
)
elif
isinstance
(
value
[
0
],
float
):
elif
isinstance
(
value
[
0
],
str
):
od_attr
.
type
=
framework_pb2
.
FLOATS
od_attr
.
type
=
framework_pb2
.
STRINGS
od_attr
.
floats
.
extend
(
value
)
od_attr
.
strings
.
extend
(
value
)
elif
isinstance
(
value
[
0
],
str
):
od_attr
.
type
=
framework_pb2
.
STRINGS
od_attr
.
strings
.
extend
(
value
)
else
:
raise
ValueError
(
'unsupported attribute {} = {}'
.
format
(
key
,
value
))
else
:
# WORKAROUND: shape of scalars is []
raise
ValueError
(
'unsupported attribute {} = {}'
.
format
(
key
,
value
))
# od_attr.type = framework_pb2.INTS
# logger.warning('using attribute %s = %s as INTS', key, value)
else
:
raise
ValueError
(
'unsupported attribute {} = {}'
.
format
(
key
,
value
))
od_attrs
.
append
(
od_attr
)
od_attrs
.
append
(
od_attr
)
return
od_attrs
return
od_attrs
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录