Unverified · Commit 84f717b2 · Authored by: J Jason · Committed by: GitHub

Merge pull request #34 from MacroBull/master

Add a very useful type-shape inference post-processing pass (for the Paddle Mobile platform)
@@ -8,8 +8,6 @@ X2Paddle supports converting Caffe and TensorFlow models to PaddlePaddle models; we also provide
Any usage issue can be reported promptly via an [ISSUE](https://github.com/PaddlePaddle/X2Paddle/issues), and you are also welcome to update the code and documentation directly through pull requests.

## [caffe2fluid](caffe2fluid)
1. Supports converting Caffe models into PaddlePaddle fluid loadable inference models
2. Provides a comparison document for common Caffe and PaddlePaddle APIs [[doc](caffe2fluid/doc)]
......
@@ -17,13 +17,13 @@ onnx2fluid supports converting ONNX models to PaddlePaddle models for inference; users
Tested successfully with the following configurations:

* python 3.5+
* onnx == 1.4.1
* paddlepaddle == 1.5.0 (optional, for validation only)

Using [Anaconda](https://docs.anaconda.com/anaconda/install):

```shell
conda install -c conda-forge onnx
pip install paddlepaddle==1.5.0
```

## Get started
@@ -49,10 +49,12 @@ onnx2fluid sample_1.onnx -t sample_1.npz

## Usage

Part of the operators in **ONNX opset 9+** are currently supported, corresponding to PyTorch **1.0/1.1 (stable opset)**; for more compatibility information see the [ONNX documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md)

onnx2fluid:

```shell
onnx2fluid [-dexy] [-o /path/to/export_dir/] [-z archive.zip] [-t test_data.npz] [-i [input_name1,input_name2]] /path/to/onnx/model.onnx

optional arguments:
  --debug, -d                       enable debugging
@@ -63,6 +65,8 @@ optional arguments:
  --output_dir, -o                  output directory
  --archive [ARCHIVE], -z [ARCHIVE]
                                    if validation passes, pack the model into the given ZIP file
  --infer_inputs, -i [input_name1,input_name2]
                                    invoke PaddlePaddle fluid type-shape inference to complete the model
```
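The same convert-then-validate flow can also be driven from Python through `onnx2fluid.cmdline.main`, whose own test invocations appear further below in this commit; a minimal sketch, with placeholder paths:

```python
from onnx2fluid.cmdline import main

# convert model.onnx into fluid code + desc pb under /tmp/export/,
# then validate the result against the golden data in test_data.npz
main(model=['model.onnx'],
     output_dir='/tmp/export/',
     embed_params=True,
     pedantic=False,
     test_data='test_data.npz',
     debug=True)
```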
The conversion tool onnx2fluid.conversion:
@@ -74,10 +78,10 @@ onnx2fluid.conversion [-dexy] [-o /path/to/export_dir/] /path/to/onnx/model.onnx

The validation tool onnx2fluid.validate:

```shell
onnx2fluid.validate [-d] [-t test_data.npz] [-i [input_name1,input_name2]] [-p 1e-3] /path/to/onnx/model.onnx
```

## References

* PaddlePaddle [operators](http://www.paddlepaddle.org/documentation/docs/zh/1.5/api_cn/layers_cn.html)
* PaddlePaddle [loading an inference model](http://www.paddlepaddle.org/documentation/docs/zh/1.5/api_guides/low_level/inference.html#id4)
@@ -19,8 +19,8 @@ PyTorch to PaddlePaddle model conversion can be easily achieved with PyTorch ONNX

## Environment and dependency

* python 3.5+ (python 2 not fully supported yet)
* onnx >= 1.4
* paddlepaddle >= 1.3.0 (optional for validation)

## Get started
@@ -47,10 +47,12 @@ onnx2fluid sample_unet.onnx -t sample_unet.npz

## Usage

**ONNX opset 9+** is mainly supported, corresponding to PyTorch **1.0/1.1 (stable opset)**; for more information see the [ONNX doc](https://github.com/onnx/onnx/blob/master/docs/Operators.md)

onnx2fluid (all in one):

```shell
onnx2fluid [-dexy] [-o /path/to/export_dir/] [-z archive.zip] [-t test_data.npz] [-i [input_name1,input_name2]] /path/to/onnx/model.onnx

optional arguments:
  --debug, -d                       enable debug logging and checking
@@ -61,6 +63,8 @@ optional arguments:
  --output_dir, -o                  output directory
  --archive [ARCHIVE], -z [ARCHIVE]
                                    compress outputs to a ZIP file if conversion succeeds
  --infer_inputs, -i [input_name1,input_name2]
                                    invoke PaddlePaddle fluid type-shape inference
```

onnx2fluid.conversion:
@@ -72,10 +76,10 @@ onnx2fluid.conversion [-dexy] [-o /path/to/export_dir/] /path/to/onnx/model.onnx

onnx2fluid.validate:

```shell
onnx2fluid.validate [-d] [-t test_data.npz] [-i [input_name1,input_name2]] [-p 1e-3] /path/to/onnx/model.onnx
```

## Reference

* [PaddlePaddle fluid operators](http://www.paddlepaddle.org/documentation/docs/en/1.5/api/layers.html)
* load converted model via [load_inference_model](http://www.paddlepaddle.org/documentation/docs/en/1.5/api/io.html#permalink-1-load_inference_model)
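A hedged sketch of that loading step, assuming the PaddlePaddle 1.x fluid API (the export directory and input shape are placeholders):

```python
import numpy as np
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
# load the converted program and its parameters from the export directory
program, feed_names, fetch_targets = fluid.io.load_inference_model('/tmp/export/', exe)

x = np.random.rand(2, 3).astype('float32')  # must match the model input shape
outputs = exe.run(program, feed={feed_names[0]: x}, fetch_list=fetch_targets)
```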
@@ -12,16 +12,16 @@ import numpy as np
from collections import OrderedDict as Dict


def make_var_name(name):
    """
    make a valid variable name in Python code
    """

    assert name

    if name[0].isdigit():
        return 'var_' + name
    for s in ' \\|/:-':  #
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
@@ -29,8 +29,8 @@ def make_var_name(name):
fn = sys.argv[1]
input_names = sys.argv[2].split(',')
output_names = sys.argv[3].split(',')
squeeze_data = len(sys.argv) > 4

data = np.load(fn, encoding='bytes')
@@ -42,7 +42,7 @@ while squeeze_data and input_data.ndim > 4 and input_data.shape[0] == 1:
while squeeze_data and output_data.ndim > 2 and output_data.shape[0] == 1:
    output_data = output_data.squeeze(0)

inputs = Dict(zip(map(make_var_name, input_names), [input_data]))
outputs = Dict(zip(map(make_var_name, output_names), [output_data]))

np.savez(fn, inputs=inputs, outputs=outputs)  # overwrite
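The resulting `.npz` is the golden-data format the validator consumes; a minimal read-back sketch (the file name is a placeholder, and `allow_pickle` is needed on newer NumPy because the dicts are stored as pickled object arrays):

```python
import numpy as np

data = np.load('sample_1.npz', allow_pickle=True)
inputs = data['inputs'].item()    # OrderedDict: var name -> input array
outputs = data['outputs'].item()  # OrderedDict: var name -> expected output
for name, arr in inputs.items():
    print(name, arr.shape, arr.dtype)
```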
@@ -15,16 +15,16 @@ from collections import OrderedDict as Dict
from glob import glob


def make_var_name(name):
    """
    make a valid variable name in Python code
    """

    assert name

    if name[0].isdigit():
        return 'var_' + name
    for s in ' \\|/:-':  #
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
@@ -32,8 +32,8 @@ def make_var_name(name):
data_dir = os.path.dirname(sys.argv[1])
input_names = sys.argv[2].split(',')
output_names = sys.argv[3].split(',')
squeeze_data = len(sys.argv) > 4

# Load inputs
@@ -58,7 +58,7 @@ for fn in glob(os.path.join(data_dir, 'output_*.pb')):
        tensor = tensor.squeeze(0)
    outputs.append(tensor)

inputs = Dict(zip(map(make_var_name, input_names), inputs))
outputs = Dict(zip(map(make_var_name, output_names), outputs))

np.savez(data_dir, inputs=inputs, outputs=outputs)
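For reference, each `output_*.pb` file this script globs holds a serialized ONNX `TensorProto`; a short sketch of loading one into NumPy (file name is a placeholder):

```python
import onnx
from onnx.numpy_helper import to_array

tensor = onnx.TensorProto()
with open('output_0.pb', 'rb') as f:
    tensor.ParseFromString(f.read())
array = to_array(tensor)  # ndarray with the tensor's dtype and shape
print(array.shape, array.dtype)
```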
@@ -20,50 +20,97 @@ from onnx2fluid.torch_export_helper import export_onnx_with_validation
prefix = 'sample_'
idx = 0

######## example: RNN cell ########


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.gru = nn.GRUCell(6, 5)
        self.lstm = nn.LSTMCell(5, 4)

    def forward(self, x, h1, h2, c2):
        h = self.gru(x, h1)
        h, c = self.lstm(h, (h2, c2))
        return h, c


model = Model()
model.eval()
xb = torch.rand((7, 6))
h1 = torch.zeros((7, 5))
h2 = torch.zeros((7, 4))
c2 = torch.zeros((7, 4))
yp = model(xb, h1, h2, c2)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb, h1, h2, c2],
                            prefix + str(idx), ['x', 'h1', 'h2', 'c2'],
                            ['h', 'c'],
                            verbose=True,
                            training=False)

######## example: RNN ########


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.gru = nn.GRU(6, 5, 3)
        self.lstm = nn.LSTM(5, 4, 2)

    def forward(self, x, h1, h2, c2):
        y, h1 = self.gru(x, h1)
        y, (h2, c2) = self.lstm(y, (h2, c2))
        return y


model = Model()
model.eval()
xb = torch.rand((8, 1, 6))
h1 = torch.zeros((3, 1, 5))
h2 = torch.zeros((2, 1, 4))
c2 = torch.zeros((2, 1, 4))
yp = model(xb, h1, h2, c2)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb, h1, h2, c2],
                            prefix + str(idx), ['x', 'h1', 'h2', 'c2'], ['y'],
                            verbose=True,
                            training=False)

######## example: random ########
"""
symbolic registration:

def rand(g, *shapes):
    shapes_list = list(shapes)
    shape = _maybe_get_const(shapes_list[0], "is")
    return g.op('RandomUniform', shape_i=shape)
"""


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        y = torch.rand((2, 3))  # + torch.rand_like(x)
        y = y + torch.randn((2, 3))  # + torch.randn_like(x)
        y = y + x
        return y


model = Model()
model.eval()
xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['x'], ['y'],
                            verbose=True,
                            training=False)
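# NOTE (editorial sketch): the "symbolic registration" docstring above shows a
# symbolic function for torch.rand; how it is wired up depends on the PyTorch
# version. On versions exposing torch.onnx.register_custom_op_symbolic (an
# assumption, not part of this commit), it would look roughly like:
#
#   from torch.onnx import register_custom_op_symbolic
#   register_custom_op_symbolic('aten::rand', rand, 9)  # opset 9
#
# after which torch.rand exports as an ONNX RandomUniform node.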
######## example: fc ########
@@ -85,11 +132,10 @@ xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['x'], ['y'],
                            verbose=True,
                            training=False)
######## example: compare ########
@@ -113,13 +159,19 @@ xb1 = torch.rand((2, 3))
ya, yb, yc = model(xb0, xb1)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb0, xb1],
                            prefix + str(idx), ['x0', 'x1'], ['ya', 'yb', 'yc'],
                            verbose=True,
                            training=False)
######## example: affine_grid ########
"""
symbolic registration:

@parse_args('v', 'is')
def affine_grid_generator(g, theta, size):
    return g.op('AffineGrid', theta, size_i=size)
"""


class Model(nn.Module):
@@ -137,11 +189,10 @@ theta = torch.rand((2, 2, 3))
grid = model(theta)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, (theta, ),
                            prefix + str(idx), ['theta'], ['grid'],
                            verbose=True,
                            training=False)
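# NOTE (editorial sketch): as with the rand example above, the
# affine_grid_generator symbolic from the docstring must be registered with
# the exporter; the exact mechanism (e.g. register_custom_op_symbolic against
# 'aten::affine_grid_generator') is version-dependent and is an assumption
# here, not part of this commit.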
######## example: conv2d_transpose ########
@@ -165,11 +216,10 @@ xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['x'], ['y'],
                            verbose=True,
                            training=False)
######## example: conv2d ########
@@ -179,7 +229,7 @@ class Model(nn.Module):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.batch_norm = nn.BatchNorm2d(8)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        y = x
@@ -195,11 +245,10 @@ xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['x'], ['y'],
                            verbose=True,
                            training=False)
######### example: conv1d ########
#
@@ -220,9 +269,10 @@ export_onnx_with_validation(
#yp = model(xb)
#idx += 1
#print('index: ', idx)
#export_onnx_with_validation(
#    model, [xb], prefix + str(idx),
#    ['x'], ['y'],
#    verbose=True, training=False)
######## example: empty ########
@@ -241,8 +291,7 @@ xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['y'], ['y'],
                            verbose=True,
                            training=False)
@@ -21,10 +21,10 @@ class double_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),
                                  nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
                                  nn.Conv2d(out_ch, out_ch, 3, padding=1),
                                  nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.conv(x)
@@ -58,8 +58,8 @@ class up(nn.Module):
        # would be a nice idea if the upsampling could be learned too,
        # but my machine do not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.Upsample(scale_factor=2,
                                  mode='bilinear')  #, align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
@@ -131,8 +131,7 @@ model = UNet(3, 80)
model.eval()
xb = torch.rand((1, 3, 512, 512))
yp = model(xb)
export_onnx_with_validation(model, [xb],
                            'sample_unet', ['image'], ['pred'],
                            verbose=True,
                            training=False)
@@ -20,188 +20,166 @@ class Yolov2(nn.Module):
    def __init__(self):
        super(Yolov2, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=32,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(in_channels=32,
                               out_channels=64,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64,
                               out_channels=128,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(in_channels=128,
                               out_channels=64,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        self.batchnorm4 = nn.BatchNorm2d(64)
        self.conv5 = nn.Conv2d(in_channels=64,
                               out_channels=128,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.batchnorm5 = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(in_channels=128,
                               out_channels=256,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.batchnorm6 = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(in_channels=256,
                               out_channels=128,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        self.batchnorm7 = nn.BatchNorm2d(128)
        self.conv8 = nn.Conv2d(in_channels=128,
                               out_channels=256,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.batchnorm8 = nn.BatchNorm2d(256)
        self.conv9 = nn.Conv2d(in_channels=256,
                               out_channels=512,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.batchnorm9 = nn.BatchNorm2d(512)
        self.conv10 = nn.Conv2d(in_channels=512,
                                out_channels=256,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False)
        self.batchnorm10 = nn.BatchNorm2d(256)
        self.conv11 = nn.Conv2d(in_channels=256,
                                out_channels=512,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.batchnorm11 = nn.BatchNorm2d(512)
        self.conv12 = nn.Conv2d(in_channels=512,
                                out_channels=256,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False)
        self.batchnorm12 = nn.BatchNorm2d(256)
        self.conv13 = nn.Conv2d(in_channels=256,
                                out_channels=512,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.batchnorm13 = nn.BatchNorm2d(512)
        self.conv14 = nn.Conv2d(in_channels=512,
                                out_channels=1024,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.batchnorm14 = nn.BatchNorm2d(1024)
        self.conv15 = nn.Conv2d(in_channels=1024,
                                out_channels=512,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False)
        self.batchnorm15 = nn.BatchNorm2d(512)
        self.conv16 = nn.Conv2d(in_channels=512,
                                out_channels=1024,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.batchnorm16 = nn.BatchNorm2d(1024)
        self.conv17 = nn.Conv2d(in_channels=1024,
                                out_channels=512,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False)
        self.batchnorm17 = nn.BatchNorm2d(512)
        self.conv18 = nn.Conv2d(in_channels=512,
                                out_channels=1024,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.batchnorm18 = nn.BatchNorm2d(1024)
        self.conv19 = nn.Conv2d(in_channels=1024,
                                out_channels=1024,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.batchnorm19 = nn.BatchNorm2d(1024)
        self.conv20 = nn.Conv2d(in_channels=1024,
                                out_channels=1024,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.batchnorm20 = nn.BatchNorm2d(1024)
        self.conv21 = nn.Conv2d(in_channels=3072,
                                out_channels=1024,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.batchnorm21 = nn.BatchNorm2d(1024)
        self.conv22 = nn.Conv2d(in_channels=1024,
                                out_channels=125,
                                kernel_size=1,
                                stride=1,
                                padding=0)

    def reorg_layer(self, x):
        stride = 2
@@ -227,14 +205,14 @@ class Yolov2(nn.Module):
        return passthrough

    def forward(self, x):
        out = F.max_pool2d(F.leaky_relu(self.batchnorm1(self.conv1(x)),
                                        negative_slope=0.1),
                           2,
                           stride=2)
        out = F.max_pool2d(F.leaky_relu(self.batchnorm2(self.conv2(out)),
                                        negative_slope=0.1),
                           2,
                           stride=2)

        out = F.leaky_relu(self.batchnorm3(self.conv3(out)), negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm4(self.conv4(out)), negative_slope=0.1)
@@ -247,36 +225,36 @@ class Yolov2(nn.Module):
        out = F.max_pool2d(out, 2, stride=2)

        out = F.leaky_relu(self.batchnorm9(self.conv9(out)), negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm10(self.conv10(out)),
                           negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm11(self.conv11(out)),
                           negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm12(self.conv12(out)),
                           negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm13(self.conv13(out)),
                           negative_slope=0.1)
        passthrough = self.reorg_layer(out)
        out = F.max_pool2d(out, 2, stride=2)

        out = F.leaky_relu(self.batchnorm14(self.conv14(out)),
                           negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm15(self.conv15(out)),
                           negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm16(self.conv16(out)),
                           negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm17(self.conv17(out)),
                           negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm18(self.conv18(out)),
                           negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm19(self.conv19(out)),
                           negative_slope=0.1)
        out = F.leaky_relu(self.batchnorm20(self.conv20(out)),
                           negative_slope=0.1)

        out = torch.cat([passthrough, out], 1)
        out = F.leaky_relu(self.batchnorm21(self.conv21(out)),
                           negative_slope=0.1)
        out = self.conv22(out)
        return out
@@ -286,8 +264,7 @@ model = Yolov2()
model.eval()
xb = torch.rand((1, 3, 224, 224))
yp = model(xb)
export_onnx_with_validation(model, [xb],
                            'sample_yolov2', ['image'], ['pred'],
                            verbose=True,
                            training=False)
@@ -92,9 +92,17 @@ parser.add_argument(
parser.add_argument(
    '--rtol',
    type=float,
    default=1e-2,
    help='assertion relative tolerance for validation',
)
parser.add_argument(
    '--infer_inputs',
    '-i',
    nargs='?',
    default=None,
    const='',
    help='perform type-shape inference with given input names and re-save model',
)
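# NOTE (editorial sketch): nargs='?' with default=None and const='' yields
# three distinct parse results (hypothetical command lines):
#   (no -i)     -> args.infer_inputs is None    (inference disabled)
#   -i          -> args.infer_inputs == ''      (bare flag takes the const)
#   -i x0,x1    -> args.infer_inputs == 'x0,x1'
# cmdline later treats "is not None" as save_inference_model=True and splits
# the string on ',' to build inference_input_names.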
args = parser.parse_args()

logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
......
@@ -22,7 +22,6 @@ __all__ = [
    'main',
]

DEFAULT_MODEL_MODULE = 'model'
DEFAULT_MODEL_FUNC = 'inference'
@@ -30,6 +29,7 @@ DEFAULT_MODEL_FUNC = 'inference'
def main(**kwargs):
    """main program entry point"""

    from .conversion import DEFAULT_ONNX_OPSET_VERSION
    from .conversion import convert

    logger = logging.getLogger('onnx2fluid')
@@ -44,41 +44,50 @@ def main(**kwargs):
                 if save_dir else basepath) + shutil.os.sep
    model_basename = DEFAULT_MODEL_MODULE + '.py'
    model_func_name = DEFAULT_MODEL_FUNC
    onnx_opset_pedantic = kwargs.pop('pedantic', True)
    skip_version_conversion = kwargs.pop('skip_version_conversion', False)
    onnx_opset_version = None if skip_version_conversion else DEFAULT_ONNX_OPSET_VERSION
    # convert
    convert(filename,
            save_dir,
            model_basename=model_basename,
            model_func_name=model_func_name,
            onnx_opset_version=onnx_opset_version,
            onnx_opset_pedantic=onnx_opset_pedantic,
            **kwargs)

    # validate
    passed = True
    golden_data_filename = kwargs.pop('test_data', '')
    infer_inputs = kwargs.pop('infer_inputs', None)
    save_inference_model = infer_inputs is not None
    if golden_data_filename or save_inference_model:
        from .validation import validate

        if save_inference_model:
            inference_input_names = infer_inputs.split(',')
        else:
            inference_input_names = None

        logger.info('starting validation on desc ...')
        passed &= validate(shutil.os.path.join(save_dir, '__model__'),
                           golden_data_filename=golden_data_filename,
                           save_inference_model=save_inference_model,
                           inference_input_names=inference_input_names,
                           **kwargs)

        logger.info('starting validation on code ...')
        # this re-generates the desc proto with Python code when debug is on
        passed &= validate(shutil.os.path.join(save_dir, model_basename),
                           golden_data_filename=golden_data_filename,
                           model_func_name=model_func_name,
                           save_inference_model=save_inference_model,
                           inference_input_names=inference_input_names,
                           **kwargs)

    if not passed:
        logger.fatal('validation failed, exit')
        return

    # create zip file
@@ -111,19 +120,17 @@ if __name__ == '__main__':
    from onnx2fluid.cmdline import main

    main(model=['../examples/t1.onnx'],
         output_dir='/tmp/export/',
         embed_params=False,
         pedantic=False,
         test_data='../examples/t1.npz',
         debug=True)

    main(model=['../examples/inception_v2/model.onnx'],
         output_dir='/tmp/export/',
         embed_params=True,
         pedantic=False,
         skip_version_conversion=False,
         test_data='../examples/inception_v2/test_data_set_2.npz',
         debug=True)
@@ -14,53 +14,72 @@ __all__ = [
    'convert',
]

DEFAULT_ONNX_OPSET_VERSION = 9


def make_var_name(name):
    """
    make a valid variable name in Python code and filename in filesystem
    """

    if name == '':
        return ''
    if name[0].isdigit():
        return 'var_' + name
    for s in ' \\|/:.-':
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
    return name
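# NOTE (editorial sketch): expected behavior on a few representative names,
# derived directly from the rules above:
#   make_var_name('')            -> ''
#   make_var_name('7')           -> 'var_7'
#   make_var_name('fc6.weight')  -> 'fc6_weight'
#   make_var_name('_hidden')     -> 'var_hidden'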
def convert(onnx_model_filename,
            save_dir,
            model_basename='model.py',
            model_func_name='inference',
            embed_params=False,
            onnx_opset_version=None,
            onnx_opset_pedantic=True,
            debug=False,
            **kwargs):
    """
    convert an ONNX model to Paddle fluid Python code and desc pb
    """

    assert isinstance(onnx_model_filename, str)
    assert isinstance(save_dir, str)
    assert isinstance(model_basename, str)
    assert isinstance(model_func_name, str)
    assert onnx_opset_version is None or isinstance(onnx_opset_version, int)

    import onnx

    from onnx.checker import ValidationError
    from onnx.checker import check_model
    from onnx.version_converter import convert_version

    from .onnx_utils import DEFAULT_OP_DOMAIN
    from .onnx_utils import graph_ops, graph_weights
    from .onnx_utils import inferred_model_value_info
    from .onnx_utils import polish_model
    from .writer import Program, Writer
    logger = logging.getLogger('convert')

    # prepare onnx model
    logger.info('loading model: %s ...', onnx_model_filename)
    onnx_model = onnx.load(onnx_model_filename)

    try:
        logger.info('checking model ...')
        check_model(onnx_model)
        if onnx_opset_version is None:  # WORKAROUND: RuntimeError: No Adapter For OP
            logger.warning(
                'opset conversion skipped for onnx_opset_pedantic is OFF')
            logger.info('assumed opset version: %d', DEFAULT_ONNX_OPSET_VERSION)
        else:
            logger.info('using opset version: %d', onnx_opset_version)
            onnx_model = convert_version(onnx_model, onnx_opset_version)
    except ValidationError as e:
        if onnx_opset_pedantic:
            raise e
@@ -68,13 +87,11 @@ def convert(onnx_model_filename,
        logger.warning('due to onnx_opset_pedantic is OFF')
        logger.warning('the ONNX model sanity checking error is suppressed')
        logger.warning('value_info inferring may be uncompleted')

    # onnx model optimization
    logger.info('model has %d ops', len(onnx_model.graph.node))
    logger.info('optimizing model ...')
    onnx_model = polish_model(onnx_model, checking=onnx_opset_pedantic)

    # prepare filesystem
    shutil.rmtree(save_dir, ignore_errors=True)
@@ -83,30 +100,31 @@ def convert(onnx_model_filename,
    # DEBUG:
    if debug:
        debug_model_filename, _ = shutil.os.path.splitext(onnx_model_filename)
        onnx.save(onnx_model, debug_model_filename + '.polished.onnx')
    # I/O instances
    onnx_graph = onnx_model.graph
    fluid_program = Program()
    fluid_writer = Writer()

    # model components
    inp_vars = [make_var_name(value.name) for value in onnx_graph.input]
    out_vars = [make_var_name(value.name) for value in onnx_graph.output]
    par_vars = []
    value_infos = inferred_model_value_info(onnx_model)
    value_infos = {
        make_var_name(key): value
        for key, value in value_infos.items()
    }

    # prepare additional value_info
    # for weights
    for name, weight in graph_weights(onnx_graph):
        var_name = make_var_name(name)
        value_info = value_infos[var_name]
        value_info['lod'] = [0]
        value_info['embedded_as'] = []
        value_info['get_weight'] = (lambda w: lambda: w.tolist())(
            weight)  # lazy getter
@@ -114,21 +132,25 @@ def convert(onnx_model_filename,
    # op set conversion
    #  topo = 'backward' if embed_params else 'forward'
    topo = 'forward'
    for name, domain, op_type, inputs, outputs, attrs in graph_ops(onnx_graph,
                                                                   topo=topo):
        op_name = make_var_name(name)
        inputs = list(map(make_var_name, inputs))
        outputs = list(map(make_var_name, outputs))
        logger.debug('translating op %s(%s) %s::%s ...', name, op_name, domain,
                     op_type)
        if domain == DEFAULT_OP_DOMAIN:
            domain = ''
        try:
            fluid_writer.emit_op(
                fluid_program,
                op_name,
                domain,
                op_type,
                inputs,
                outputs,
                attrs,
                value_infos,
                embed_params=embed_params,
            )
        except BaseException as e:
@@ -140,53 +162,74 @@ def convert(onnx_model_filename,
    logger.info('%d ops in, %d ops out', len(onnx_graph.node),
                len(fluid_program.op_descs))
    # type-shape info copy
    for var_name, value_info in value_infos.items():
        fluid_program.VarTypeShapeInfo(var_name, value_info,
                                       remove_batch=False)  #
    bad_vars = []
    for var_name, var_desc in fluid_program.var_descs.items():
        if not var_desc.type.lod_tensor.HasField('tensor'):
            bad_vars.append(var_name)
    if bad_vars:
        logger.warning('type-shape not infered for var %s ...',
                       ', '.join(bad_vars[:5]))
        logger.warning('this causes little problem for PaddlePaddle, '
                       'but Paddle Mobile may not infer correctly')
        logger.warning('please consider running validation with -i '
                       'to invoke type-shape inference in PaddlePaddle')
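    # NOTE (editorial sketch): the suggested -i run corresponds to a command
    # like (hypothetical paths and input names):
    #   onnx2fluid /path/to/onnx/model.onnx -o /path/to/export_dir/ -i x0,x1
    # which makes cmdline call validate(..., save_inference_model=True,
    # inference_input_names=['x0', 'x1']), so PaddlePaddle fluid re-infers
    # each var's type and shape and re-saves the inference model.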
    # weight writer
    for name, weight in graph_weights(onnx_graph):
        var_name = make_var_name(name)
        par_vars.append(var_name)
        value_info = value_infos[var_name]
        embedded_names = value_info.get('embedded_as', [])
        if embedded_names:
            if len(embedded_names) > 1:
                logger.info(
                    'weight %s is shared between ops, more disk space will be consumed',
                    name)
            logger.debug('saving weight %s(%s[%d], %dB) as %s ...', name,
                         weight.dtype, weight.size, weight.nbytes,
                         embedded_names)
            for embedded_name in embedded_names:  # multiple references
                fluid_writer.write_weight(weight,
                                          shutil.os.path.join(
                                              save_dir, embedded_name),
                                          lod=value_info['lod'])
        else:
            logger.debug('saving weight %s(%s[%d], %dB) to %s ...', name,
                         weight.dtype, weight.size, weight.nbytes, var_name)
            fluid_writer.write_weight(weight,
                                      shutil.os.path.join(save_dir, var_name),
                                      lod=value_info['lod'])
        fluid_writer.emit_param(fluid_program, var_name, value_info)
    param_codes = fluid_program.codes
    fluid_program.codes = []
    logger.info('%d weights converted', len(par_vars))
    # input writer
    external_inputs = []
    for var_name in inp_vars:
        if var_name not in par_vars:
            value_info = value_infos[var_name]
            assert value_info['external']
            external_inputs.append(var_name)
    fluid_writer.emit_inputs(fluid_program,
                             external_inputs,
                             value_infos,
                             remove_batch=False)  # TODO:
    input_codes = fluid_program.codes
    fluid_program.codes = []
    logger.info('%d inputs converted', len(external_inputs))

    # output writer
    external_outputs = []
    for var_name in out_vars:
        if var_name not in par_vars:
            value_info = value_infos[var_name]
            assert value_info['external']
            external_outputs.append(var_name)
    fluid_writer.emit_outputs(fluid_program, external_outputs)
    output_codes = [''] + fluid_program.codes  # add an empty line
    fluid_program.codes = []
@@ -194,10 +237,18 @@ def convert(onnx_model_filename,
    # code generation
    header_codes = fluid_writer.header_code(
        model_func_name,
        'From: {}'.format(onnx_model_filename),
    )
    code_filename = shutil.os.path.join(save_dir, model_basename)
    fluid_writer.write_code_file(
        code_filename,
        header_codes,
        input_codes,
        param_codes,
        op_codes,
        output_codes,
    )
    logger.info('code saved to %s, factory function: %s', code_filename,
                model_func_name)
@@ -206,19 +257,16 @@ def convert(onnx_model_filename,
    fluid_writer.write_desc_file(
        desc_filename,
        op_descs=fluid_program.op_descs,
        var_descs=list(fluid_program.var_descs.values()),
    )
    logger.info('program saved to %s', desc_filename)

    logger.info('conversion finished')
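# NOTE (editorial sketch): convert() can also be called programmatically with
# the signature shown above; paths here are placeholders:
#
#   from onnx2fluid.conversion import convert
#   convert('model.onnx', '/tmp/export/',
#           embed_params=True, onnx_opset_pedantic=False, debug=True)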
def main():
    import argparse

    parser = argparse.ArgumentParser(
        description='onnx2fluid.convert',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -283,10 +331,17 @@ def main():
    pedantic = args.pedantic
    skip_version_conversion = args.skip_version_conversion

    convert(model_filename,
            save_dir,
            embed_params=embed_params,
            onnx_opset_pedantic=pedantic,
            onnx_skip_version_conversion=skip_version_conversion,
            debug=debug)


if __name__ == '__main__':
    del convert
    from onnx2fluid.conversion import convert

    main()
@@ -28,30 +28,66 @@ _ATTRTYPE = _descriptor.EnumDescriptor(
    filename=None,
    file=DESCRIPTOR,
    values=[
        _descriptor.EnumValueDescriptor(name='INT',
                                        index=0,
                                        number=0,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='FLOAT',
                                        index=1,
                                        number=1,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='STRING',
                                        index=2,
                                        number=2,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='INTS',
                                        index=3,
                                        number=3,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='FLOATS',
                                        index=4,
                                        number=4,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='STRINGS',
                                        index=5,
                                        number=5,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='BOOLEAN',
                                        index=6,
                                        number=6,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='BOOLEANS',
                                        index=7,
                                        number=7,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='BLOCK',
                                        index=8,
                                        number=8,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='LONG',
                                        index=9,
                                        number=9,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='BLOCKS',
                                        index=10,
                                        number=10,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='LONGS',
                                        index=11,
                                        number=11,
                                        options=None,
                                        type=None),
    ],
    containing_type=None,
    options=None,
@@ -80,53 +116,111 @@ _VARTYPE_TYPE = _descriptor.EnumDescriptor(
    filename=None,
    file=DESCRIPTOR,
    values=[
        _descriptor.EnumValueDescriptor(name='BOOL',
                                        index=0,
                                        number=0,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='INT16',
                                        index=1,
                                        number=1,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='INT32',
                                        index=2,
                                        number=2,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='INT64',
                                        index=3,
                                        number=3,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='FP16',
                                        index=4,
                                        number=4,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='FP32',
                                        index=5,
                                        number=5,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='FP64',
                                        index=6,
                                        number=6,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='SIZE_T',
                                        index=7,
                                        number=19,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='UINT8',
                                        index=8,
                                        number=20,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='INT8',
                                        index=9,
                                        number=21,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='LOD_TENSOR',
                                        index=10,
                                        number=7,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='SELECTED_ROWS',
                                        index=11,
                                        number=8,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='FEED_MINIBATCH',
                                        index=12,
                                        number=9,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='FETCH_LIST',
                                        index=13,
                                        number=10,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='STEP_SCOPES',
                                        index=14,
                                        number=11,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='LOD_RANK_TABLE',
                                        index=15,
                                        number=12,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='LOD_TENSOR_ARRAY',
                                        index=16,
                                        number=13,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='PLACE_LIST',
                                        index=17,
                                        number=14,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='READER',
                                        index=18,
                                        number=15,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='RAW',
                                        index=19,
                                        number=17,
                                        options=None,
                                        type=None),
        _descriptor.EnumValueDescriptor(name='TUPLE',
                                        index=20,
                                        number=18,
                                        options=None,
                                        type=None),
    ],
    containing_type=None,
    options=None,
@@ -1480,11 +1574,10 @@ DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
Version = _reflection.GeneratedProtocolMessageType(
    'Version',
    (_message.Message, ),
    dict(DESCRIPTOR=_VERSION,
         __module__='framework_pb2'
         # @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
         ))
_sym_db.RegisterMessage(Version)

OpDesc = _reflection.GeneratedProtocolMessageType(
@@ -1601,11 +1694,10 @@ _sym_db.RegisterMessage(VarType.Tuple)
VarDesc = _reflection.GeneratedProtocolMessageType(
    'VarDesc',
    (_message.Message, ),
    dict(DESCRIPTOR=_VARDESC,
         __module__='framework_pb2'
         # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
         ))
_sym_db.RegisterMessage(VarDesc)

BlockDesc = _reflection.GeneratedProtocolMessageType(
......
@@ -11,9 +11,11 @@ from __future__ import division
import logging
import numpy as np
import onnx
import onnx.optimizer as optimizer

from collections import OrderedDict as Dict  # as default dict
from onnx.checker import check_model
from onnx.helper import get_attribute_value, make_attribute, strip_doc_string
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from onnx.numpy_helper import to_array
from onnx.shape_inference import infer_shapes
@@ -23,14 +25,16 @@ logger = logging.getLogger(__name__)
__all__ = [
    'print_pb_structure',
    'build_value_refs',
    'tensor_dtype',
    'tensor_shape',
    'node_attrs',
    'node_topo',
    'node_iter',
    'graph_ops',
    'graph_weights',
    'inferred_model_value_info',
    'polish_model',
    'polish_and_save',
    'optimize_model_skip_op_for_inference',
    'optimize_model_strip_initializer',
    'optimize_model_cast',
@@ -50,17 +54,17 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'): if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'):
for field in message.DESCRIPTOR.fields: for field in message.DESCRIPTOR.fields:
print('\t' * depth + '-', field.name) print('\t' * depth + '-', field.name)
print_pb_structure( print_pb_structure(getattr(message, field.name),
getattr(message, field.name), loop_iterative=loop_iterative,
loop_iterative=loop_iterative, depth=(depth + 1))
depth=(depth + 1))
if loop_iterative and hasattr(message, 'MergeFrom') and hasattr( if loop_iterative and hasattr(message, 'MergeFrom') and hasattr(
message, '__len__'): message, '__len__'):
for idx, item in enumerate(message): for idx, item in enumerate(message):
print('\t' * depth + '-', idx) print('\t' * depth + '-', idx)
print_pb_structure( print_pb_structure(item,
item, loop_iterative=loop_iterative, depth=(depth + 1)) loop_iterative=loop_iterative,
depth=(depth + 1))
def build_value_refs(nodes): def build_value_refs(nodes):
...@@ -83,14 +87,21 @@ def get_attribute_value2(attr): ...@@ -83,14 +87,21 @@ def get_attribute_value2(attr):
get_attribute_value enhanced get_attribute_value enhanced
""" """
assert isinstance(
attr, onnx.AttributeProto), 'attr is not a AttributeProto instance'
if attr.type == onnx.AttributeProto.TENSOR: if attr.type == onnx.AttributeProto.TENSOR:
dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type]) dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
data = attr.t.raw_data data = attr.t.raw_data
value = np.frombuffer( value = np.frombuffer(data,
data, dtype=dtype, count=(len(data) // dtype.itemsize)) dtype=dtype,
count=(len(data) // dtype.itemsize))
elif attr.type == onnx.AttributeProto.STRING: elif attr.type == onnx.AttributeProto.STRING:
value = attr.s value = attr.s
value = value.decode() if isinstance(value, bytes) else value value = value.decode() if isinstance(value, bytes) else value
elif attr.type == onnx.AttributeProto.STRINGS:
value = attr.strings
value = [s.decode() if isinstance(s, bytes) else s for s in value]
else: else:
value = get_attribute_value(attr) value = get_attribute_value(attr)
return value return value
...@@ -101,6 +112,9 @@ def tensor_dtype(tensor): ...@@ -101,6 +112,9 @@ def tensor_dtype(tensor):
get ONNX tensor in np.dtype get ONNX tensor in np.dtype
""" """
assert isinstance(
tensor, onnx.ValueInfoProto), 'tensor is not a ValueInfoProto instance'
return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type] return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]
...@@ -109,7 +123,10 @@ def tensor_shape(tensor): ...@@ -109,7 +123,10 @@ def tensor_shape(tensor):
get ONNX tensor shape get ONNX tensor shape
""" """
return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim] assert isinstance(
tensor, onnx.ValueInfoProto), 'tensor is not a ValueInfoProto instance'
return tuple([dim.dim_value for dim in tensor.type.tensor_type.shape.dim])
def node_attrs(node): def node_attrs(node):
...@@ -117,6 +134,8 @@ def node_attrs(node): ...@@ -117,6 +134,8 @@ def node_attrs(node):
convert ONNX node attributes to dict convert ONNX node attributes to dict
""" """
assert isinstance(node, onnx.NodeProto), 'node is not a NodeProto instance'
return {attr.name: get_attribute_value2(attr) return {attr.name: get_attribute_value2(attr)
for attr in node.attribute} # dict for attr in node.attribute} # dict
...@@ -145,12 +164,12 @@ def node_topo(nodes, topo='default'): ...@@ -145,12 +164,12 @@ def node_topo(nodes, topo='default'):
for node_idx, degree in enumerate(node_in_degrees): for node_idx, degree in enumerate(node_in_degrees):
if degree == 0: if degree == 0:
queue.append(node_idx) queue.append(node_idx)
while len(queue) > 0: while queue:
node_idx = queue.pop(0) node_idx = queue.pop(0)
node_topo.append(node_idx) node_topo.append(node_idx)
for val_name in nodes[node_idx].output: for val_name in nodes[node_idx].output:
output_refs[val_name].remove(node_idx) output_refs[val_name].remove(node_idx)
if len(output_refs[val_name]) > 0: if output_refs[val_name]:
continue continue
output_refs.pop(val_name) output_refs.pop(val_name)
if val_name not in input_refs: if val_name not in input_refs:
...@@ -170,12 +189,12 @@ def node_topo(nodes, topo='default'): ...@@ -170,12 +189,12 @@ def node_topo(nodes, topo='default'):
for node_idx, degree in enumerate(node_out_degrees): for node_idx, degree in enumerate(node_out_degrees):
if degree == 0: if degree == 0:
queue.append(node_idx) queue.append(node_idx)
while len(queue) > 0: while queue:
node_idx = queue.pop(0) node_idx = queue.pop(0)
node_topo.append(node_idx) node_topo.append(node_idx)
for val_name in nodes[node_idx].input: for val_name in nodes[node_idx].input:
input_refs[val_name].remove(node_idx) input_refs[val_name].remove(node_idx)
if len(input_refs[val_name]) > 0: if input_refs[val_name]:
continue continue
input_refs.pop(val_name) input_refs.pop(val_name)
if val_name not in output_refs: if val_name not in output_refs:
...@@ -208,6 +227,11 @@ def node_iter(nodes, indices=None): ...@@ -208,6 +227,11 @@ def node_iter(nodes, indices=None):
if name == '': if name == '':
name = 'op_' + str(index) name = 'op_' + str(index)
# else: # make_op_name
# for s in ' \\|/:-': #
# name = name.replace(s, '_')
if domain == '': if domain == '':
domain = DEFAULT_OP_DOMAIN domain = DEFAULT_OP_DOMAIN
...@@ -219,9 +243,8 @@ def graph_ops(graph, topo='default'): ...@@ -219,9 +243,8 @@ def graph_ops(graph, topo='default'):
generator for ONNX node graph with given topology generator for ONNX node graph with given topology
""" """
if not isinstance(graph, onnx.GraphProto): assert isinstance(graph,
logger.error('graph is not a GraphProto instance') onnx.GraphProto), 'graph is not a GraphProto instance'
return
return node_iter(graph.node, node_topo(graph.node, topo)) return node_iter(graph.node, node_topo(graph.node, topo))
...@@ -231,9 +254,8 @@ def graph_weights(graph): ...@@ -231,9 +254,8 @@ def graph_weights(graph):
generator for weights of an ONNX model generator for weights of an ONNX model
""" """
if not isinstance(graph, onnx.GraphProto): assert isinstance(graph,
logger.error('graph is not a GraphProto instance') onnx.GraphProto), 'graph is not a GraphProto instance'
return
for initializer in graph.initializer: for initializer in graph.initializer:
name = initializer.name name = initializer.name
...@@ -246,29 +268,32 @@ def inferred_model_value_info(model): ...@@ -246,29 +268,32 @@ def inferred_model_value_info(model):
collect value/type info for an ONNX model collect value/type info for an ONNX model
""" """
assert isinstance(model,
onnx.ModelProto), 'model is not a ModelProto instance'
model = infer_shapes(model) model = infer_shapes(model)
graph = model.graph graph = model.graph
value_info = Dict() value_info = Dict()
for item in graph.value_info: for item in graph.value_info:
value_info[item.name] = dict( value_info[item.name] = {
dtype=tensor_dtype(item), 'dtype': tensor_dtype(item),
shape=tensor_shape(item), 'shape': tensor_shape(item),
external=False, 'external': False,
) }
for item in graph.input: for item in graph.input:
assert item.name not in value_info assert item.name not in value_info
value_info[item.name] = dict( value_info[item.name] = {
dtype=tensor_dtype(item), 'dtype': tensor_dtype(item),
shape=tensor_shape(item), 'shape': tensor_shape(item),
external=True, 'external': True,
) }
for item in graph.output: for item in graph.output:
# assert item.name not in value_info, 'bypass-model not supported' # assert item.name not in value_info, 'bypass-model not supported'
value_info[item.name] = dict( value_info[item.name] = {
dtype=tensor_dtype(item), 'dtype': tensor_dtype(item),
shape=tensor_shape(item), 'shape': tensor_shape(item),
external=True, 'external': True,
) }
return value_info return value_info
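
For illustration, a minimal usage sketch of these helpers; the model path and the `onnx2fluid.onnx_utils` module path are assumptions, not confirmed by this diff:

```python
# a minimal sketch, assuming a local 'model.onnx' and this module's import path
import onnx
from onnx2fluid.onnx_utils import inferred_model_value_info, graph_ops  # path assumed

model = onnx.load('model.onnx')
value_info = inferred_model_value_info(model)
for name, info in value_info.items():
    # 'external' marks graph inputs/outputs as opposed to internal values
    print(name, info['dtype'], info['shape'], 'external' if info['external'] else 'internal')
for name, domain, op_type, inputs, outputs, attrs in graph_ops(model.graph, topo='forward'):
    print('{}::{}'.format(domain, op_type), inputs, '->', outputs)
```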
@@ -302,12 +327,63 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
    return processed


def polish_model(model, internals=True, extras=True, checking=True):
    """
    polish_model enhanced for inference
    """

    if checking:
        check_model(model)
    strip_doc_string(model)
    if internals:
        passes = optimizer.get_available_passes()
        passes = list(filter(lambda name: not name.startswith('split_'),
                             passes))  #
        logger.debug('builtin optimizations to perform in ONNX:\n\t%s', passes)
        model = optimizer.optimize(model, passes=passes)
    if extras:
        for optimize in (
                optimize_model_skip_op_for_inference,
                optimize_model_strip_initializer,
                optimize_model_cast,
                optimize_model_slice,
        ):
            model = optimize(model)
    model = infer_shapes(model)
    if checking:
        check_model(model)
    return model


def polish_and_save(model_filename,
                    suffix='.polished',
                    save_filename=None,
                    *args,
                    **kwargs):
    """
    run polish_model and save
    """

    if save_filename is None:
        save_filename = model_filename.replace('.onnx', suffix + '.onnx')
    model = onnx.load(model_filename)
    model = polish_model(model, *args, **kwargs)
    onnx.save(model, save_filename)
    logger.info('polished model saved to: %s', save_filename)
    return save_filename
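
A minimal usage sketch of the polishing entry point; it assumes the ONNX version pinned above (1.4.x, where `onnx.optimizer` is still a bundled module) and reuses the quick-start sample name:

```python
# a minimal sketch; 'sample_1.onnx' is the quick-start example, module path assumed
from onnx2fluid.onnx_utils import polish_and_save

# runs checker -> strip_doc_string -> builtin optimizer passes -> extra
# inference-only rewrites, then saves 'sample_1.polished.onnx'
polished_filename = polish_and_save('sample_1.onnx')
```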
def optimize_model_skip_op_for_inference(model, op_list=None):
    """
    skip ops can be bypassed for inference
    """

    assert isinstance(model,
                      onnx.ModelProto), 'model is not a ModelProto instance'

    if op_list is None:
        op_list = ('Dropout', 'Identity')

    nodes = model.graph.node
    input_refs, output_refs = build_value_refs(nodes)

@@ -322,10 +398,10 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
        if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
            continue
        op_type = node.op_type
        if op_type not in op_list:
            continue

        if op_type in ('Dropout', ):
            input_name = node.input[0]
            output_name = node.output[0]
        elif not (len(node.input) == 1 and len(node.output) == 1):

@@ -368,6 +444,9 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
    strip weights for inference
    """

    assert isinstance(model,
                      onnx.ModelProto), 'model is not a ModelProto instance'

    nodes = model.graph.node
    input_refs, output_refs = build_value_refs(nodes)
    out_names = [val.name for val in model.graph.output]

@@ -406,9 +485,12 @@ def optimize_model_strip_initializer(model, keep_input_only=True):

def optimize_model_cast(model):
    """
    strip cascade and unecessary onnx::Cast-9:
    """

    assert isinstance(model,
                      onnx.ModelProto), 'model is not a ModelProto instance'

    nodes = model.graph.node
    input_refs, output_refs = build_value_refs(nodes)
    value_info = inferred_model_value_info(model)

@@ -422,7 +504,7 @@ def optimize_model_cast(model):
    for node_idx, node in enumerate(nodes):
        if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
            continue
        if node.op_type != 'Cast':
            continue
        attrs = node_attrs(node)
        output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']]

@@ -463,19 +545,22 @@ def optimize_model_cast(model):

def optimize_model_slice(model):
    """
    strip cascade and unecessary onnx::Slice-1:9
    """

    assert isinstance(model,
                      onnx.ModelProto), 'model is not a ModelProto instance'

    nodes = model.graph.node
    input_refs, output_refs = build_value_refs(nodes)

    def build_slice_node_chain(node_idx):
        chain = []
        while True:
            node = nodes[node_idx]
            if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
                return chain
            if node.op_type != 'Slice':
                return chain
            chain.append(node_idx)
            output_name = node.output[0]

@@ -485,7 +570,7 @@ def optimize_model_slice(model):
            node_idx = list(input_refs[output_name])[0]

    # axis: (start, end)
    def merge_slice(slice_chain):
        merged_slice = dict()
        for slice_node_idx in slice_chain:
            node = nodes[slice_node_idx]

@@ -508,14 +593,14 @@ def optimize_model_slice(model):
    ret_nodes = ret.graph.node
    nodes_to_remove = []
    for node_idx in range(len(nodes)):
        slice_chain = build_slice_node_chain(node_idx)
        if not slice_chain:
            continue

        merged_slice = merge_slice(slice_chain)
        if merged_slice and len(slice_chain) == 1:  # no need to merge
            continue

        attrs = {'axes': [], 'starts': [], 'ends': []}
        for axis, (start, end) in merged_slice.items():
            attrs['axes'].append(axis)
            attrs['starts'].append(start)
            attrs['ends'].append(end)

@@ -526,12 +611,11 @@ def optimize_model_slice(model):
        output_name = last_node.output[0]
        processed = -1
        if output_name in input_refs:  # 0, [1...]
            new_input_name = first_node.output[0] if merged_slice else input_name
            processed = skip_node_forward(ret_nodes, output_name,
                                          new_input_name, input_refs)
            if processed > 0:
                if merged_slice:
                    remain_idx = slice_chain[0]
                    remove_chain = slice_chain[1:]
                    slice_node = ret_nodes[remain_idx]

@@ -545,12 +629,11 @@ def optimize_model_slice(model):
                    remove_chain = slice_chain

        if processed < 0 and input_name in output_refs:
            new_output_name = last_node.input[0] if merged_slice else output_name
            processed = skip_node_backward(ret_nodes, input_name,
                                           new_output_name, output_refs)
            if processed > 0:
                if merged_slice:
                    remain_idx = slice_chain[-1]
                    remove_chain = slice_chain[:-1]
                    slice_node = ret_nodes[remain_idx]

@@ -565,7 +648,7 @@ def optimize_model_slice(model):
        if processed > 0:
            nodes_to_remove.extend(remove_chain)
            if not merged_slice:
                logger.debug('skip slice chain %s -> %s -> %s', input_name,
                             slice_chain, output_name)
        elif processed < 0:  # NEVERFIX: not merge standalone slice chain

@@ -586,22 +669,16 @@ if __name__ == '__main__':
        level=logging.DEBUG,
    )

    from onnx.version_converter import convert_version

    model = onnx.load('/tmp/export.onnx')
    print_pb_structure(model, loop_iterative=False)

    check_model(model)
    model = convert_version(model, 9)
    model = polish_model(model)
    onnx.save(model, '/tmp/export.polished.onnx')

    graph = model.graph
    value_info = inferred_model_value_info(model)

@@ -613,23 +690,23 @@ if __name__ == '__main__':

    logger.info('ops:')
    for name, domain, op_type, _, _, attrs in graph_ops(graph, topo='forward'):
        logger.info('- \t%s %s::%s: %s', name, domain, op_type, attrs)

    logger.info('weights:')
    for name, array in graph_weights(graph):
        weights.append(name)
        logger.info('- \t%s: %s', name, array.shape)

    logger.info('inputs:')
    external_inputs = []
    for name in inputs:
        if name not in weights:
            external_inputs.append(name)
            logger.info('- \t%s: %s', name, value_info[name]['shape'])

    logger.info('outputs:')
    external_outputs = []
    for name in outputs:
        if name not in weights:
            external_outputs.append(name)
            logger.info('- \t%s: %s', name, value_info[name]['shape'])
@@ -6,115 +6,180 @@ Created on Fri Mar 22 11:22:46 2019
@author: Macrobull
"""

from __future__ import division

import logging

import numpy as np
import torch

from collections import OrderedDict
from typing import (
    TypeVar,
    Any,
    Generic,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Text,
    Tuple,
    Union,
)

logger = logging.getLogger(__name__)

__all__ = [
    'export_data',
    'export_onnx_with_validation',
]

my_dict = OrderedDict
KT = TypeVar('KT')
VT = TypeVar('VT')


class MyDict(my_dict, Generic[KT, VT]):
    pass


def ensure_list(obj: Union[object, Sequence[object]]) -> List[object]:
    if isinstance(obj, (list, tuple, set)):
        return list(obj)
    return [obj]


def ensure_tuple(obj: Union[object, Sequence[object]]) -> Tuple[object, ...]:
    if isinstance(obj, (tuple, list, set)):
        return tuple(obj)
    return (obj, )


def flatten_list(obj: List[Union[object, List[object]]],
                 out: Optional[List[object]] = None) -> List[object]:
    assert isinstance(obj, list), 'list type required'

    if out is None:
        out = type(obj)()
    for item in obj:
        if isinstance(item, list):
            flatten_list(item, out)
        else:
            out.append(item)
    return out


def export_data(state_dict: Mapping[Text, Any], prefix: Text = '') -> None:
    """
    export binary data with meta text for raw C++ inference engines
    """

    def str_(obj: object) -> Text:
        if isinstance(obj, (tuple, list, set)):
            return str(obj)[1:-1].replace(' ', '')
        return str(obj)

    prefix_ = prefix + ('_' if prefix else '')
    fp = open('{}.txt'.format(prefix or 'meta'), mode='w')
    for key, value in state_dict.items():
        data = None
        if torch.is_tensor(value):
            data = value.data.cpu().numpy()
        elif isinstance(value, np.ndarray):
            data = value
        if data is not None:
            data.tofile('{}{}.bin'.format(prefix_, key))
            fp.write('{}.dtype={}\n'.format(key, str_(data.dtype.name)))
            fp.write('{}.shape={}\n'.format(key, str_(data.shape)))
        else:
            fp.write('{}={}\n'.format(key, str_(value)))
    fp.close()
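
A minimal usage sketch of `export_data`; the prefix and resulting filenames are illustrative, derived from the format string in the function above:

```python
# a minimal sketch: dump a torchvision state_dict to flat binaries plus a meta text
import torch
from torchvision.models import resnet18

model = resnet18()
# writes resnet18.txt plus one raw .bin per tensor,
# e.g. resnet18_conv1.weight.bin (prefix 'resnet18' is an arbitrary choice)
export_data(model.state_dict(), prefix='resnet18')
```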
def export_onnx_with_validation(
        model: torch.nn.Module,  # or JITScriptModule
        inputs: Sequence[Union[torch.Tensor, Sequence[object]]],
        export_basepath: Text,
        input_names: Optional[List[Text]] = None,
        output_names: Optional[List[Text]] = None,
        use_npz: bool = True,
        *args,
        **kwargs) -> Sequence[Union[torch.Tensor, Sequence[object]]]:
    """
    export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
    """

    is_tuple_or_list = lambda x: isinstance(x, (tuple, list))

    def tensors_to_arrays(tensors: Union[torch.Tensor, Iterable[
            Union[torch.Tensor, Iterable[Any]]]], ) -> List[np.ndarray]:
        if torch.is_tensor(tensors):
            return tensors.data.cpu().numpy()
        return list(map(tensors_to_arrays, tensors))

    def zip_dict(
            keys: Optional[Iterable[Any]],
            values: Sequence[Union[Any, Sequence[Any]]],
    ) -> MyDict[Text, Union[object, MyDict[Text, object]]]:
        keys = keys or range(len(values))
        ret = my_dict()
        for idx, (key, value) in enumerate(zip(keys, values)):
            is_key_list = is_tuple_or_list(key)
            is_value_list = is_tuple_or_list(value)
            assert is_key_list == is_value_list, 'keys and values mismatch'
            if is_value_list:
                ret[str(idx)] = zip_dict(key, value)
            else:
                ret[key] = value
        return ret

    torch_inputs = ensure_tuple(inputs)  # WORKAROUND: for torch.onnx
    outputs = torch.onnx.export(model,
                                torch_inputs,
                                export_basepath + '.onnx',
                                input_names=(None if input_names is None else
                                             flatten_list(input_names)),
                                output_names=(None if output_names is None else
                                              flatten_list(output_names)),
                                *args,
                                **kwargs)
    if outputs is None:  # WORKAROUND: for torch.onnx
        training = kwargs.get('training', False)
        with torch.onnx.set_training(model, training):
            outputs = model(*inputs)
    torch_outputs = ensure_tuple(outputs)

    inputs = zip_dict(input_names, tensors_to_arrays(torch_inputs))
    outputs = zip_dict(output_names, tensors_to_arrays(torch_outputs))
    if use_npz:
        np.savez(
            export_basepath + '.npz',
            inputs=inputs,
            outputs=outputs,
        )
    else:
        np.save(export_basepath + '.npy',
                np.asarray(my_dict(inputs=inputs, outputs=outputs)),
                allow_pickle=True)

    return torch_outputs


if __name__ == '__main__':
    from torchvision.models import resnet18 as net

    model = net()
    xb = torch.rand((1, 3, 224, 224))
    export_onnx_with_validation(
        model,
        (xb, ),
        '/tmp/export',
        input_names=[
            'image',
        ],
        output_names=[
            'prob',
        ],
        use_npz=True,
    )
@@ -8,38 +8,85 @@ Created on Fri Mar 22 12:17:19 2019

import importlib, logging, os, sys

logger = logging.getLogger(__name__)

__all__ = [
    'fluid_prog_shape_infer',
    'validate',
]


def flatten_dict(obj, out=None):
    assert isinstance(obj, dict), 'dict type required'

    if out is None:
        out = type(obj)()
    for key, value in obj.items():
        if isinstance(value, dict):
            flatten_dict(value, out)
        else:
            assert key not in out, 'key conflicted'
            out[key] = value
    return out


def ensure_list(obj):
    if isinstance(obj, (list, tuple, set)):
        return list(obj)
    return [obj]


def fluid_prog_shape_infer(prog):
    """
    additional type-shape inference for fluid program
    """

    import paddle.fluid as fluid

    assert isinstance(prog,
                      fluid.framework.Program), 'prog is not a Program instance'

    logger.info('performing type-shape inference ...')
    for block in prog.blocks:
        block_desc = block.desc

        for idx_op in range(block_desc.op_size()):
            op_desc = block_desc.op(idx_op)
            if op_desc.type() in ('feed', 'fetch'):
                continue

            op_desc.infer_var_type(block_desc)
            op_desc.infer_shape(block_desc)

        for var_name, var in block.vars.items():
            var_desc = var.desc
            if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR:
                continue

            # WORKAROUND: dirty way to give dtype to partial-infered vars
            # which could not be cleared!
            try:
                var.to_string(True)
            except ValueError:
                var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32)
                logger.debug('dtype of var %s not inferred, float32 assumed',
                             var_name)
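
A minimal sketch of running this post-processing pass on a loaded program; the export directory path and the `onnx2fluid.validation` import path are assumptions for illustration:

```python
# a minimal sketch, assuming a converted fluid model directory (path hypothetical)
import paddle.fluid as fluid
from onnx2fluid.validation import fluid_prog_shape_infer  # module path assumed

exe = fluid.Executor(fluid.CPUPlace())
prog, feed_names, fetch_vars = fluid.io.load_inference_model('/tmp/export.paddle', exe)
fluid_prog_shape_infer(prog)  # fills in missing var types/shapes in-place
```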
def validate(fluid_model_filename,
             golden_data_filename='',
             atol=1e-3,
             rtol=1e-3,
             model_func_name='inference',
             save_inference_model=False,
             inference_input_names=None,
             **kwargs):
    """
    inference the converted Paddle fluid model, validate with given golden data
    """

    assert isinstance(fluid_model_filename, str)

    import numpy as np
    import paddle.fluid as fluid

@@ -52,12 +99,12 @@ def validate(fluid_model_filename,
    # load model
    fluid_model_dir, basename = os.path.split(fluid_model_filename)
    if basename == '__model__':  # is desc program
        logger.info('using desc file %s', basename)
        prog, _, var_outs = fluid.io.load_inference_model(fluid_model_dir, exe)
        out_names = var_outs  # HINT: pass var if fetch ops already created
        logger.info('model load passed')
    elif basename.endswith('.py'):  # is Python code
        logger.info('using code file %s', basename)
        module_name, _ = os.path.splitext(basename)
        sys_path = sys.path.copy()
        sys.path.append(fluid_model_dir)

@@ -73,74 +120,92 @@ def validate(fluid_model_filename,
                         func)
        var_outs = func()
        var_outs = ensure_list(var_outs)
        out_names = [var.name for var in var_outs
                     ]  # HINT: pass string to create fetch ops
        logger.info('import passed')

        prog = fluid.default_main_program()
        fluid.io.load_persistables(executor=exe,
                                   dirname=fluid_model_dir,
                                   main_program=prog)
        logger.info('weight load passed')
    else:
        raise ValueError('unsupported Paddle fluid model filename')

    # load data
    if golden_data_filename:
        logger.info('using golden data %s', golden_data_filename)
        if golden_data_filename.endswith('.npz'):
            test_data = np.load(
                golden_data_filename,
                encoding='bytes',
                allow_pickle=True,
            )
            input_data = test_data['inputs'].tolist()
            output_data = test_data['outputs'].tolist()
        else:
            test_data = np.load(
                golden_data_filename,
                encoding='bytes',
                allow_pickle=True,
            ).tolist()
            input_data = test_data['inputs']
            output_data = test_data['outputs']

        input_data = flatten_dict(input_data)
        output_data = flatten_dict(output_data)
        input_names = input_data.keys()
        # output_names = output_data.keys()
        logger.info('with %d inputs and %d outputs', len(input_data),
                    len(output_data))
    elif save_inference_model:
        assert inference_input_names is not None, (
            'input names required for type-shape inference')

        input_names = inference_input_names
        logger.info('using input names: %s', ', '.join(input_names))

    # type-shape inference and re-save
    if save_inference_model:
        fluid_prog_shape_infer(prog)
        fluid.io.save_inference_model(fluid_model_dir,
                                      input_names,
                                      var_outs,
                                      exe,
                                      main_program=prog,
                                      export_for_deployment=True)
        logger.info('model re-save passed')

        fluid.io.load_inference_model(fluid_model_dir, exe)
        logger.info('model re-load passed')

    if golden_data_filename == '':
        return True

    # execute
    outputs = exe.run(prog, feed=input_data,
                      fetch_list=out_names)  # out_names can be vars
    logger.info('execution passed')

    # validate
    passed = True
    for (name, truth), output in zip(output_data.items(), outputs):
        logger.info('testing on output {} ...'.format(name))
        try:
            np.testing.assert_allclose(output,
                                       truth,
                                       rtol=rtol,
                                       atol=atol,
                                       equal_nan=False,
                                       verbose=True)
        except AssertionError as e:
            passed = False
            logger.error('failed: %s\n', e)
    logger.info('accuracy %spassed', '' if passed else 'not ')
    return passed
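
A minimal usage sketch of `validate`, mirroring the quick-start sample; the exported directory name is illustrative, and the import path is an assumption:

```python
# a minimal sketch; paths are illustrative, module path assumed
from onnx2fluid.validation import validate

passed = validate('sample_1.paddle/__model__',
                  golden_data_filename='sample_1.npz',
                  atol=1e-3, rtol=1e-3)
print('validation passed:', passed)
```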
def main():
    import argparse

    parser = argparse.ArgumentParser(

@@ -162,6 +227,7 @@ if __name__ == '__main__':
        '--test_data',
        '-t',
        type=str,
        default='',
        help='I/O golden data for validation, e.g. test.npy, test.npz',
    )
    parser.add_argument(

@@ -174,23 +240,39 @@ if __name__ == '__main__':
    parser.add_argument(
        '--rtol',
        type=float,
        default=1e-2,
        help='assertion relative tolerance for validation',
    )
    parser.add_argument(
        '--infer_inputs',
        '-i',
        nargs='?',
        default=None,
        const='',
        help=
        'perform type-shape inference with given input names and re-save model',
    )
    args = parser.parse_args()

    logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(format=logging_format, level=logging_level)

    # debug = args.debug
    fluid_model_filename = args.model[0]
    golden_data_filename = args.test_data
    atol, rtol = args.atol, args.rtol
    save_inference_model = args.infer_inputs is not None
    inference_input_names = args.infer_inputs.split(
        ',') if args.infer_inputs else None

    validate(fluid_model_filename,
             golden_data_filename=golden_data_filename,
             atol=atol,
             rtol=rtol,
             save_inference_model=save_inference_model,
             inference_input_names=inference_input_names)


if __name__ == '__main__':
    main()
@@ -11,10 +11,11 @@ from __future__ import division

import logging, os
import numpy as np

from collections import OrderedDict as Dict

logger = logging.getLogger(__name__)

from . import symbolic

try:
    import paddle.fluid.proto.framework_pb2 as framework_pb2

@@ -30,7 +31,7 @@ __all__ = [
]


def irepr(obj, to='_'):
    """inline repr"""

    s = repr(obj)

@@ -41,12 +42,14 @@ def irepr(obj, to='_'):
    return s


def flatten_list(obj, out=None):
    assert isinstance(obj, list), 'list type required'

    if out is None:
        out = type(obj)()
    for item in obj:
        if isinstance(item, list):
            flatten_list(item, out)
        else:
            out.append(item)
    return out

@@ -57,9 +60,9 @@ def make_attr_name(name):
    make a valid code name for ParamAttr
    """

    assert name != '', 'name should not be empty'

    for s in ' \\|/:.-':  #
        name = name.replace(s, '_')
    if not name.startswith('_'):
        name = '_' + name

@@ -93,7 +96,7 @@ class Program(object):
        return Program.DTYPE_TO_FRAMEWORK_DTYPE[dtype]

    @staticmethod
    def OpDescVars(keys, vals):
        """
        make (OpDesc.Var)s
        """

@@ -130,8 +133,8 @@ class Program(object):
            od_attr.type = framework_pb2.STRING
            od_attr.s = value
        elif isinstance(value, list):
            if value:  # TODO: test all items
                if isinstance(value[0],
                              bool):  # bool.mro() = [bool, int, object]
                    od_attr.type = framework_pb2.BOOLEANS
                    od_attr.bools.extend(value)

@@ -147,13 +150,11 @@ class Program(object):
                else:
                    raise ValueError('unsupported attribute {} = {}'.format(
                        key, value))
            else:  # WORKAROUND: [] not inferred
                # raise ValueError('unsupported attribute {} = {}'.format(key, value))
                od_attr.type = framework_pb2.INTS
                logger.warning('using attribute %s = %s as INTS', key, value)
        else:
            raise ValueError('unsupported attribute {} = {}'.format(
                key, value))

@@ -164,14 +165,15 @@ class Program(object):
        self.code_mutable = True
        self.codes = []
        self.op_descs = []
        self.var_descs = Dict()

    def __repr__(self):
        return ('Program(code mutable: {}) with:\n'
                'codes: {}\n'
                'op_descs: {}\n'
                'var_descs: {}\n').format(self.code_mutable, self.codes,
                                          self.op_descs,
                                          list(self.var_descs.values()))

    def Code(self, code):
        """

@@ -181,23 +183,16 @@ class Program(object):
        if self.code_mutable:
            self.codes.append(code)

    def OpDesc(self, op_type, input_key_vals, output_key_vals, attrs):
        """
        add OpDesc
        """

        desc = framework_pb2.OpDesc()
        desc.type = op_type
        desc.inputs.extend(self.OpDescVars(*input_key_vals))
        desc.outputs.extend(self.OpDescVars(*output_key_vals))
        desc.attrs.extend(self.OpDescAttrs(attrs))
        self.op_descs.append(desc)
        return desc

@@ -210,26 +205,18 @@ class Program(object):
        add VarDesc,
        """

        assert name not in self.var_descs, 'var name {} conflicts'.format(name)

        var_desc = framework_pb2.VarDesc()
        var_desc.name = name
        var_desc.persistable = persistable
        var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
        self.var_descs[name] = var_desc

        if value_info is not None:
            self.VarTypeShapeInfo(name, value_info, remove_batch=remove_batch)

    def Op(self, domain, op_type, inputs, outputs, attrs, *args, **kwargs):
        """
        convert an ONNX op and add it to program
        """

@@ -238,15 +225,17 @@ class Program(object):
            raise ValueError('only default domain supported')

        if op_type in symbolic.DEFAULT_OP_MAPPING:
            symbolic._default(self, op_type, inputs, outputs, attrs, *args,
                              **kwargs)
        elif hasattr(symbolic, op_type):
            fn = getattr(symbolic, op_type)
            fn(self, inputs, outputs, attrs, *args, **kwargs)
        else:
            raise ValueError('conversion for {}::{} not supported'.format(
                domain, op_type))

    def IntermediateOp(self, domain, op_type, inputs, outputs, attrs, *args,
                       **kwargs):
        """
        convert an intermediate ONNX op declaring in desc program only
        """

@@ -254,20 +243,47 @@ class Program(object):
        code_mutable = self.code_mutable
        self.code_mutable = False
        try:
            self.Op(domain, op_type, inputs, outputs, attrs, *args, **kwargs)
        except BaseException as e:
            self.code_mutable = code_mutable
            raise e
        else:
            self.code_mutable = code_mutable

    def VarTypeShapeInfo(self, name, value_info, remove_batch=None):
        """
        set value_info for var
        """

        if name not in self.var_descs:
            return

        dtype = value_info.get('dtype', None)
        if dtype is None:
            return

        var_desc = self.var_descs[name]
        tensor_desc = var_desc.type.lod_tensor.tensor
        tensor_desc.data_type = self.Dtype(dtype)  # required

        shape = value_info.get('shape', None)
        if not shape:  # None or scalars
            return

        tensor_desc.dims.extend(shape)
        if remove_batch is None:
            remove_batch = value_info.get('remove_batch',
                                          False)  # not persistable)
        if remove_batch:
            tensor_desc.dims[0] = -1


class Writer(object):
    """
    fluid code and desc writter
    """

    CODE_INDENT = ' ' * 4  # '\t'

    @staticmethod
    def header_code(func_name, info=''):

@@ -275,7 +291,7 @@ class Writer(object):
        Python header codes
        """

        codes = []
        codes.append('"""')
        codes.append('This code is generated by onnx2fluid.')
        codes.append('{}'.format(info))

@@ -287,6 +303,7 @@ class Writer(object):
        codes.append('from paddle.fluid import initializer, layers')
        codes.append('')
        codes.append('')

        codes.append('def {}():'.format(func_name))
        return codes

@@ -299,17 +316,16 @@ class Writer(object):

        prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type,
                                                      inputs, outputs,
                                                      irepr(attrs, to=', ')))
        prog.Op(domain,
                op_type,
                inputs,
                outputs,
                attrs,
                value_infos=value_infos,
                name=name,
                *args,
                **kwargs)

    @staticmethod
    def emit_param(prog, name, value_info):

@@ -317,24 +333,26 @@ class Writer(object):
        emit an ONNX weight into program
        """

        embedded_names = value_info.get('embedded_as', [])
        if embedded_names:
            prog.Code('# parameter {} embedded as {}'.format(
                name, embedded_names))
            for embedded_name in embedded_names:
                prog.VarDesc(embedded_name,
                             persistable=True,
                             value_info=value_info)
        else:
            attr_name = make_attr_name(name)
            prog.Code('# parameter {}'.format(name))
            prog.Code('{} = ParamAttr(name={})'  # , trainable=True
                      .format(attr_name, repr(name)))
            prog.Code(
                '{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
                ', default_initializer=initializer.Constant(0))'  #, is_bias={}
                .format(name, value_info['shape'],
                        repr(value_info['dtype'].name), repr(name),
                        attr_name))  #, value_info.get('is_bias', False)))
            prog.VarDesc(name, persistable=True, value_info=value_info)

    @staticmethod
    def emit_inputs(prog, names, value_infos, remove_batch=None):

@@ -343,7 +361,6 @@ class Writer(object):
        """

        for idx, name in enumerate(names):
            value_info = value_infos[name]
            shape = value_info['shape']
            if remove_batch is None:

@@ -352,25 +369,24 @@ class Writer(object):
            if remove_batch:
                shape = shape[1:]

            prog.Code('# input {}'.format(name))
            prog.Code((
                '{} = layers.data(name={}, shape={}, dtype={}, '
                'append_batch_size={})'  # , stop_gradient=True
            ).format(
                name,
                repr(name),
                shape,
                repr(value_info['dtype'].name),
                remove_batch,
            ))
            prog.OpDesc(
                'feed',
                (['X'], ['feed']),
                (['Out'], [name]),
                {'col': idx},
            )
            prog.VarDesc(name, value_info=value_info, remove_batch=remove_batch)

    @staticmethod
    def emit_outputs(prog, names):  #, value_infos

@@ -380,14 +396,13 @@ class Writer(object):
        code = 'return '
        for idx, name in enumerate(names):
            code += name + ', '
            prog.OpDesc(
                'fetch',
                (['X'], [name]),
                (['Out'], ['fetch']),
                {'col': idx},
            )
            # var is emitted over ops
        prog.Code(code)

@@ -398,18 +413,22 @@ class Writer(object):
        flatten codes in program
        """

        for code in flatten_list(others):
            codes.append(Writer.CODE_INDENT * indent + code)
        return codes

    @staticmethod
    def write_weight(weight, filename, lod=None):
        """
        write single weight in fluid desc
        """

        assert isinstance(weight, np.ndarray), 'weight is not an ndarray'
        assert lod is None or isinstance(lod,
                                         list), 'lod should be None or list'

        if lod is None:
            lod = [0]

        tensor_desc = framework_pb2.VarType.TensorDesc()
        tensor_desc.data_type = Program.Dtype(weight.dtype)

@@ -417,7 +436,7 @@ class Writer(object):
        fp = open(filename, 'wb')
        np.array([0], dtype=np.int32).tofile(fp)  # version
        np.array(lod, dtype=np.int64).tofile(fp)  # LOD level
        np.array([0], dtype=np.int32).tofile(fp)  # tensor version
        np.array([tensor_desc.ByteSize()], dtype=np.int32).tofile(fp)
        fp.write(tensor_desc.SerializeToString())

@@ -431,11 +450,9 @@ class Writer(object):
        """

        for name, weight in weights.items():
            assert isinstance(weights, dict), 'dict type weights required'

            filename = os.path.join(save_dir, name)
            Writer.write_weight(weight, filename)
            logger.debug('saved weight %s to %s', name, filename)

@@ -451,7 +468,7 @@ class Writer(object):
        Writer.add_codes(codes, body_code, 1)

        fp = open(filename, 'w')
        for code in flatten_list(codes):
            fp.write(code)
            fp.write('\n')
        fp.close()
...
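
A minimal sketch of serializing one weight with `Writer.write_weight` as defined in the diff above; the module path and output filename are assumptions for illustration:

```python
# a minimal sketch: write one ndarray in the fluid persistable-weight file layout
# (version, LoD level, tensor version, TensorDesc, then raw data)
import numpy as np
from onnx2fluid.writer import Writer  # module path assumed

weight = np.zeros((64, 3, 7, 7), dtype=np.float32)  # e.g. a conv kernel (hypothetical)
Writer.write_weight(weight, '/tmp/conv1_weights')
```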
-e .
onnx>=1.4
paddlepaddle>=1.5
@@ -19,13 +19,13 @@ license = MIT
# pick matching entries from the official PyPI classifier list
# https://pypi.org/pypi?%3Aaction=list_classifiers
classifier =
    Private :: Do Not Upload
    Programming Language :: Python
    Programming Language :: Python :: 3
    Programming Language :: Python :: 3.5

# keywords for indexing, so users can find the project
keywords =
    onnx paddlepaddle

[options]
# package name; "find:" enables auto-discovery, configurable in options.packages.find
@@ -34,7 +34,7 @@ packages = find:
# one dependency per line; list direct dependencies only, indirect ones are usually unnecessary
# keep the version constraints here loose; a minimum version plus the major version usually suffices
install_requires =
    onnx >= 1.4

# test dependencies: extra libraries needed when testing, same format as install_requires
# the built-in unittest works, as do simpler frameworks such as pytest or nose
@@ -53,7 +53,9 @@ zip_safe = True
# the following config turns the given functions into command-line tools users can run directly
[options.entry_points]
console_scripts =
    onnx2fluid = onnx2fluid.__main__
    onnx2fluid_convert = onnx2fluid.conversion:main
    onnx2fluid_validate = onnx2fluid.validation:main

# the following config bundles non-.py files such as conf or data into the package, installed alongside it under site-packages
# files only, directories unsupported, but wildcards are allowed
...