Commit 7c3e9379 authored by M Macrobull

bugfix

Parent 816ac6e2
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019

@author: Macrobull
"""

import sys
import numpy as np

from collections import OrderedDict as Dict


def make_var_name(name):
    """
    make a valid variable name in Python code
    """

    if name == '':
        return '_'
    if name[0].isdigit():
        return 'var_' + name
    for s in ' \\|/:':  #
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
    return name


# CLI: <npz-file> <input-names> <output-names> [any 4th argument enables squeezing]
fn = sys.argv[1]
input_names = sys.argv[2].split(',')
output_names = sys.argv[3].split(',')
squeeze_data = len(sys.argv) > 4

data = np.load(fn, encoding='bytes')
input_data = data['inputs']
output_data = data['outputs']

# drop leading singleton (batch-like) dimensions if requested
while squeeze_data and input_data.ndim > 4 and input_data.shape[0] == 1:
    input_data = input_data.squeeze(0)
while squeeze_data and output_data.ndim > 2 and output_data.shape[0] == 1:
    output_data = output_data.squeeze(0)

inputs = Dict(zip(map(make_var_name, input_names), [input_data]))
outputs = Dict(zip(map(make_var_name, output_names), [output_data]))

np.savez(fn, inputs=inputs, outputs=outputs)  # overwrite
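For reference, a minimal sketch of reading back an archive produced by the script above; the file name sample.npz is illustrative, and allow_pickle=True is needed on newer NumPy because the dicts are stored as 0-d object arrays:

import numpy as np

# load the converted archive and unwrap the pickled dicts with .item()
data = np.load('sample.npz', allow_pickle=True)
inputs = data['inputs'].item()    # e.g. {'data_0': ndarray}
outputs = data['outputs'].item()  # e.g. {'prob_1': ndarray}
for name, arr in inputs.items():
    print(name, arr.shape, arr.dtype)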
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019

@author: Macrobull
"""

import os, sys
import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper

from collections import OrderedDict as Dict
from glob import glob


def make_var_name(name):
    """
    make a valid variable name in Python code
    """

    if name == '':
        return '_'
    if name[0].isdigit():
        return 'var_' + name
    for s in ' \\|/:':  #
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
    return name


# CLI: <test-data path> <input-names> <output-names> [any 4th argument enables squeezing]
data_dir = os.path.dirname(sys.argv[1])
input_names = sys.argv[2].split(',')
output_names = sys.argv[3].split(',')
squeeze_data = len(sys.argv) > 4

# Load inputs
inputs = []
for fn in glob(os.path.join(data_dir, 'input_*.pb')):
    tensor = onnx.TensorProto()
    with open(fn, 'rb') as f:
        tensor.ParseFromString(f.read())
    tensor = numpy_helper.to_array(tensor)
    while squeeze_data and tensor.ndim > 4 and tensor.shape[0] == 1:
        tensor = tensor.squeeze(0)
    inputs.append(tensor)

# Load outputs
outputs = []
for fn in glob(os.path.join(data_dir, 'output_*.pb')):
    tensor = onnx.TensorProto()
    with open(fn, 'rb') as f:
        tensor.ParseFromString(f.read())
    tensor = numpy_helper.to_array(tensor)
    while squeeze_data and tensor.ndim > 2 and tensor.shape[0] == 1:
        tensor = tensor.squeeze(0)
    outputs.append(tensor)

inputs = Dict(zip(map(make_var_name, input_names), inputs))
outputs = Dict(zip(map(make_var_name, output_names), outputs))

np.savez(data_dir, inputs=inputs, outputs=outputs)
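The sanitization rules of make_var_name above can be summarized by a few illustrative assertions (not part of the script):

assert make_var_name('') == '_'                         # empty name -> placeholder
assert make_var_name('0') == 'var_0'                    # leading digit gets a prefix
assert make_var_name('gpu_0/data_0') == 'gpu_0_data_0'  # ' \|/:' all map to '_'
assert make_var_name('_hidden') == 'var_hidden'         # leading '_' gets a prefix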
@@ -39,7 +39,7 @@ idx = 0
 #yp = model(xb)
 #idx += 1
 #print('index: ', idx)
-#export_onnx_with_validation(model, (xb, ), prefix + str(idx),
+#export_onnx_with_validation(model, [xb], prefix + str(idx),
 #                            ['x'], ['y'],
 #                            verbose=True, training=False)
@@ -61,7 +61,7 @@ idx = 0
 #yp = model(xb)
 #idx += 1
 #print('index: ', idx)
-#export_onnx_with_validation(model, (xb, ), prefix + str(idx),
+#export_onnx_with_validation(model, [xb], prefix + str(idx),
 #                            ['x'], ['y'],
 #                            verbose=True, training=False)
@@ -85,8 +85,7 @@ xb = torch.rand((2, 3))
 yp = model(xb)
 idx += 1
 print('index: ', idx)
-export_onnx_with_validation(
-    model, (xb, ),
+export_onnx_with_validation(model, [xb],
     prefix + str(idx), ['x'], ['y'],
     verbose=True,
     training=False)
@@ -113,8 +112,7 @@ xb1 = torch.rand((2, 3))
 ya, yb, yc = model(xb0, xb1)
 idx += 1
 print('index: ', idx)
-export_onnx_with_validation(
-    model, (xb0, xb1),
+export_onnx_with_validation(model, [xb0, xb1],
     prefix + str(idx), ['x0', 'x1'], ['ya', 'yb', 'yc'],
     verbose=True,
     training=False)
@@ -137,8 +135,7 @@ theta = torch.rand((2, 2, 3))
 grid = model(theta)
 idx += 1
 print('index: ', idx)
-export_onnx_with_validation(
-    model, (theta, ),
+export_onnx_with_validation(model, (theta, ),
     prefix + str(idx), ['theta'], ['grid'],
     verbose=True,
     training=False)
@@ -165,8 +162,7 @@ xb = torch.rand((2, 3, 4, 5))
 yp = model(xb)
 idx += 1
 print('index: ', idx)
-export_onnx_with_validation(
-    model, (xb, ),
+export_onnx_with_validation(model, [xb],
     prefix + str(idx), ['x'], ['y'],
     verbose=True,
     training=False)
@@ -195,8 +191,7 @@ xb = torch.rand((2, 3, 4, 5))
 yp = model(xb)
 idx += 1
 print('index: ', idx)
-export_onnx_with_validation(
-    model, (xb, ),
+export_onnx_with_validation(model, [xb],
     prefix + str(idx), ['x'], ['y'],
     verbose=True,
     training=False)
@@ -220,7 +215,7 @@ export_onnx_with_validation(
 #yp = model(xb)
 #idx += 1
 #print('index: ', idx)
-#export_onnx_with_validation(model, (xb, ), prefix + str(idx),
+#export_onnx_with_validation(model, [xb], prefix + str(idx),
 #                            ['x'], ['y'],
 #                            verbose=True, training=False)
@@ -241,8 +236,7 @@ xb = torch.rand((2, 3))
 yp = model(xb)
 idx += 1
 print('index: ', idx)
-export_onnx_with_validation(
-    model, (xb, ),
+export_onnx_with_validation(model, [xb],
     prefix + str(idx), ['y'], ['y'],
     verbose=True,
     training=False)
@@ -21,9 +21,9 @@ class double_conv(nn.Module):
     def __init__(self, in_ch, out_ch):
         super(double_conv, self).__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch),
-            nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1),
+        self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),
+                                  nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
+                                  nn.Conv2d(out_ch, out_ch, 3, padding=1),
             nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

     def forward(self, x):
@@ -58,8 +58,8 @@ class up(nn.Module):
         # would be a nice idea if the upsampling could be learned too,
         # but my machine do not have enough memory to handle all those weights
         if bilinear:
-            self.up = nn.Upsample(
-                scale_factor=2, mode='bilinear') #, align_corners=True)
+            self.up = nn.Upsample(scale_factor=2,
+                                  mode='bilinear') #, align_corners=True)
         else:
             self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
@@ -131,8 +131,7 @@ model = UNet(3, 80)
 model.eval()
 xb = torch.rand((1, 3, 512, 512))
 yp = model(xb)
-export_onnx_with_validation(
-    model, (xb, ),
+export_onnx_with_validation(model, [xb],
     'sample_unet', ['image'], ['pred'],
     verbose=True,
     training=False)
@@ -20,8 +20,7 @@ class Yolov2(nn.Module):
     def __init__(self):
         super(Yolov2, self).__init__()
-        self.conv1 = nn.Conv2d(
-            in_channels=3,
+        self.conv1 = nn.Conv2d(in_channels=3,
             out_channels=32,
             kernel_size=3,
             stride=1,
@@ -29,8 +28,7 @@ class Yolov2(nn.Module):
             bias=False)
         self.batchnorm1 = nn.BatchNorm2d(32)
-        self.conv2 = nn.Conv2d(
-            in_channels=32,
+        self.conv2 = nn.Conv2d(in_channels=32,
             out_channels=64,
             kernel_size=3,
             stride=1,
@@ -38,24 +36,21 @@ class Yolov2(nn.Module):
             bias=False)
         self.batchnorm2 = nn.BatchNorm2d(64)
-        self.conv3 = nn.Conv2d(
-            in_channels=64,
+        self.conv3 = nn.Conv2d(in_channels=64,
             out_channels=128,
             kernel_size=3,
             stride=1,
             padding=1,
             bias=False)
         self.batchnorm3 = nn.BatchNorm2d(128)
-        self.conv4 = nn.Conv2d(
-            in_channels=128,
+        self.conv4 = nn.Conv2d(in_channels=128,
             out_channels=64,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=False)
         self.batchnorm4 = nn.BatchNorm2d(64)
-        self.conv5 = nn.Conv2d(
-            in_channels=64,
+        self.conv5 = nn.Conv2d(in_channels=64,
             out_channels=128,
             kernel_size=3,
             stride=1,
@@ -63,24 +58,21 @@ class Yolov2(nn.Module):
             bias=False)
         self.batchnorm5 = nn.BatchNorm2d(128)
-        self.conv6 = nn.Conv2d(
-            in_channels=128,
+        self.conv6 = nn.Conv2d(in_channels=128,
             out_channels=256,
             kernel_size=3,
             stride=1,
             padding=1,
             bias=False)
         self.batchnorm6 = nn.BatchNorm2d(256)
-        self.conv7 = nn.Conv2d(
-            in_channels=256,
+        self.conv7 = nn.Conv2d(in_channels=256,
             out_channels=128,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=False)
         self.batchnorm7 = nn.BatchNorm2d(128)
-        self.conv8 = nn.Conv2d(
-            in_channels=128,
+        self.conv8 = nn.Conv2d(in_channels=128,
             out_channels=256,
             kernel_size=3,
             stride=1,
@@ -88,40 +80,35 @@ class Yolov2(nn.Module):
             bias=False)
         self.batchnorm8 = nn.BatchNorm2d(256)
-        self.conv9 = nn.Conv2d(
-            in_channels=256,
+        self.conv9 = nn.Conv2d(in_channels=256,
             out_channels=512,
             kernel_size=3,
             stride=1,
             padding=1,
             bias=False)
         self.batchnorm9 = nn.BatchNorm2d(512)
-        self.conv10 = nn.Conv2d(
-            in_channels=512,
+        self.conv10 = nn.Conv2d(in_channels=512,
             out_channels=256,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=False)
         self.batchnorm10 = nn.BatchNorm2d(256)
-        self.conv11 = nn.Conv2d(
-            in_channels=256,
+        self.conv11 = nn.Conv2d(in_channels=256,
             out_channels=512,
             kernel_size=3,
             stride=1,
             padding=1,
             bias=False)
         self.batchnorm11 = nn.BatchNorm2d(512)
-        self.conv12 = nn.Conv2d(
-            in_channels=512,
+        self.conv12 = nn.Conv2d(in_channels=512,
             out_channels=256,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=False)
         self.batchnorm12 = nn.BatchNorm2d(256)
-        self.conv13 = nn.Conv2d(
-            in_channels=256,
+        self.conv13 = nn.Conv2d(in_channels=256,
             out_channels=512,
             kernel_size=3,
             stride=1,
@@ -129,40 +116,35 @@ class Yolov2(nn.Module):
             bias=False)
         self.batchnorm13 = nn.BatchNorm2d(512)
-        self.conv14 = nn.Conv2d(
-            in_channels=512,
+        self.conv14 = nn.Conv2d(in_channels=512,
             out_channels=1024,
             kernel_size=3,
             stride=1,
             padding=1,
             bias=False)
         self.batchnorm14 = nn.BatchNorm2d(1024)
-        self.conv15 = nn.Conv2d(
-            in_channels=1024,
+        self.conv15 = nn.Conv2d(in_channels=1024,
             out_channels=512,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=False)
         self.batchnorm15 = nn.BatchNorm2d(512)
-        self.conv16 = nn.Conv2d(
-            in_channels=512,
+        self.conv16 = nn.Conv2d(in_channels=512,
             out_channels=1024,
             kernel_size=3,
             stride=1,
             padding=1,
             bias=False)
         self.batchnorm16 = nn.BatchNorm2d(1024)
-        self.conv17 = nn.Conv2d(
-            in_channels=1024,
+        self.conv17 = nn.Conv2d(in_channels=1024,
             out_channels=512,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=False)
         self.batchnorm17 = nn.BatchNorm2d(512)
-        self.conv18 = nn.Conv2d(
-            in_channels=512,
+        self.conv18 = nn.Conv2d(in_channels=512,
             out_channels=1024,
             kernel_size=3,
             stride=1,
@@ -170,16 +152,14 @@ class Yolov2(nn.Module):
             bias=False)
         self.batchnorm18 = nn.BatchNorm2d(1024)
-        self.conv19 = nn.Conv2d(
-            in_channels=1024,
+        self.conv19 = nn.Conv2d(in_channels=1024,
             out_channels=1024,
             kernel_size=3,
             stride=1,
             padding=1,
             bias=False)
         self.batchnorm19 = nn.BatchNorm2d(1024)
-        self.conv20 = nn.Conv2d(
-            in_channels=1024,
+        self.conv20 = nn.Conv2d(in_channels=1024,
             out_channels=1024,
             kernel_size=3,
             stride=1,
@@ -187,8 +167,7 @@ class Yolov2(nn.Module):
             bias=False)
         self.batchnorm20 = nn.BatchNorm2d(1024)
-        self.conv21 = nn.Conv2d(
-            in_channels=3072,
+        self.conv21 = nn.Conv2d(in_channels=3072,
             out_channels=1024,
             kernel_size=3,
             stride=1,
@@ -196,8 +175,7 @@ class Yolov2(nn.Module):
             bias=False)
         self.batchnorm21 = nn.BatchNorm2d(1024)
-        self.conv22 = nn.Conv2d(
-            in_channels=1024,
+        self.conv22 = nn.Conv2d(in_channels=1024,
             out_channels=125,
             kernel_size=1,
             stride=1,
@@ -227,12 +205,12 @@ class Yolov2(nn.Module):
         return passthrough

     def forward(self, x):
-        out = F.max_pool2d(
-            F.leaky_relu(self.batchnorm1(self.conv1(x)), negative_slope=0.1),
+        out = F.max_pool2d(F.leaky_relu(self.batchnorm1(self.conv1(x)),
+                                        negative_slope=0.1),
             2,
             stride=2)
-        out = F.max_pool2d(
-            F.leaky_relu(self.batchnorm2(self.conv2(out)), negative_slope=0.1),
+        out = F.max_pool2d(F.leaky_relu(self.batchnorm2(self.conv2(out)),
+                                        negative_slope=0.1),
             2,
             stride=2)
@@ -247,36 +225,36 @@ class Yolov2(nn.Module):
         out = F.max_pool2d(out, 2, stride=2)
         out = F.leaky_relu(self.batchnorm9(self.conv9(out)), negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm10(self.conv10(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm10(self.conv10(out)),
+                           negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm11(self.conv11(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm11(self.conv11(out)),
+                           negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm12(self.conv12(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm12(self.conv12(out)),
+                           negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm13(self.conv13(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm13(self.conv13(out)),
+                           negative_slope=0.1)
         passthrough = self.reorg_layer(out)
         out = F.max_pool2d(out, 2, stride=2)
-        out = F.leaky_relu(
-            self.batchnorm14(self.conv14(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm14(self.conv14(out)),
+                           negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm15(self.conv15(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm15(self.conv15(out)),
+                           negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm16(self.conv16(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm16(self.conv16(out)),
+                           negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm17(self.conv17(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm17(self.conv17(out)),
+                           negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm18(self.conv18(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm18(self.conv18(out)),
+                           negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm19(self.conv19(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm19(self.conv19(out)),
+                           negative_slope=0.1)
-        out = F.leaky_relu(
-            self.batchnorm20(self.conv20(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm20(self.conv20(out)),
+                           negative_slope=0.1)
         out = torch.cat([passthrough, out], 1)
-        out = F.leaky_relu(
-            self.batchnorm21(self.conv21(out)), negative_slope=0.1)
+        out = F.leaky_relu(self.batchnorm21(self.conv21(out)),
+                           negative_slope=0.1)
         out = self.conv22(out)
         return out
@@ -286,8 +264,7 @@ model = Yolov2()
 model.eval()
 xb = torch.rand((1, 3, 224, 224))
 yp = model(xb)
-export_onnx_with_validation(
-    model, (xb, ),
+export_onnx_with_validation(model, [xb],
     'sample_yolov2', ['image'], ['pred'],
     verbose=True,
     training=False)
@@ -11,6 +11,7 @@ validate_flags2="/tmp/export/__model__"
 alias http_get="aria2c -c -s8 -x8"
 # alias python="python3" # if ...

 bvlc_alexnet()
 {
     bn_tar="bvlc_alexnet"
@@ -26,17 +27,19 @@ bvlc_alexnet()
     for npz in "$bn_tar"/*.npz
     do
         echo "converting $npz ..."
-        python convert_data_npz_0.py "$npz" data_0 prob_1 -s
+        python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir ..."
-        python convert_data_pb_0.py "$pb_dir" data_0 prob_1
+        python convert_data_pb.py "$pb_dir" data_0 prob_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
 }

 bvlc_googlenet()
@@ -54,10 +57,12 @@ bvlc_googlenet()
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir"
-        python convert_data_pb_0.py "$pb_dir" data_0 prob_1
+        python convert_data_pb.py "$pb_dir" data_0 prob_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
 }

 bvlc_reference_caffenet()
@@ -75,10 +80,12 @@ bvlc_reference_caffenet()
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir"
-        python convert_data_pb_0.py "$pb_dir" data_0 prob_1
+        python convert_data_pb.py "$pb_dir" data_0 prob_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
 }

 bvlc_reference_rcnn_ilsvrc13()
@@ -96,10 +103,65 @@ bvlc_reference_rcnn_ilsvrc13()
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir"
-        python convert_data_pb_0.py "$pb_dir" data_0 fc-rcnn_1
+        python convert_data_pb.py "$pb_dir" data_0 fc-rcnn_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
+}
+
+densenet121()
+{
+    bn_tar="densenet121"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/model.onnx"
+    http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid $convert_flags "$fn_model"
+    for npz in "$bn_tar"/*.npz
+    do
+        echo "converting $npz ..."
+        python convert_data_npz.py "$npz" data_0 fc6_1 -s
+        python -m onnx2fluid.validation $validate_flags1 -t "$npz"
+        python -m onnx2fluid.validation $validate_flags2 -t "$npz"
+    done
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir"
+        python convert_data_pb.py "$pb_dir" data_0 fc6_1
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
+}
+
+emotion_ferplus()
+{
+    bn_tar="emotion_ferplus"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/model.onnx"
+    http_get "https://onnxzoo.blob.core.windows.net/models/opset_8/emotion_ferplus/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid $convert_flags "$fn_model" -y
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" Input3 Plus692_Output_0
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
 }

 inception_v1()
@@ -117,17 +179,19 @@ inception_v1()
     for npz in "$bn_tar"/*.npz
     do
         echo "converting $npz ..."
-        python convert_data_npz_0.py "$npz" data_0 prob_1 -s
+        python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir ..."
-        python convert_data_pb_0.py "$pb_dir" data_0 prob_1
+        python convert_data_pb.py "$pb_dir" data_0 prob_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
 }

 inception_v2()
@@ -145,17 +209,65 @@ inception_v2()
     for npz in "$bn_tar"/*.npz
     do
         echo "converting $npz ..."
-        python convert_data_npz_0.py "$npz" data_0 prob_1 -s
+        python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir ..."
-        python convert_data_pb_0.py "$pb_dir" data_0 prob_1
+        python convert_data_pb.py "$pb_dir" data_0 prob_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
+}
+
+mobilenet()
+{
+    bn_tar="mobilenetv2-1.0"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/$bn_tar.onnx"
+    http_get "https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid $convert_flags "$fn_model" -y
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" data mobilenetv20_output_flatten0_reshape0
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
+}
+
+resnet18()
+{
+    bn_tar="resnet18v1"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/$bn_tar.onnx"
+    http_get "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v1/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid $convert_flags "$fn_model" -y
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" data resnetv15_dense0_fwd
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
 }

 resnet50()
@@ -173,17 +285,88 @@ resnet50()
     for npz in "$bn_tar"/*.npz
     do
         echo "converting $npz ..."
-        python convert_data_npz_0.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
+        python convert_data_npz.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir ..."
-        python convert_data_pb_0.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1
+        python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
+}
+
+resnet100_arcface()
+{
+    bn_tar="resnet100"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/$bn_tar.onnx"
+    http_get "https://s3.amazonaws.com/onnx-model-zoo/arcface/resnet100/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid -o /tmp/export/ "$fn_model" -y
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" data fc1
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
+}
+
+resnet101_duc()
+{
+    bn_tar="ResNet101_DUC_HDC"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/$bn_tar.onnx"
+    http_get "https://s3.amazonaws.com/onnx-model-zoo/duc/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid $convert_flags "$fn_model" -y
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" data seg_loss
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
+}
+
+resnet152()
+{
+    bn_tar="resnet152v2"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/$bn_tar.onnx"
+    http_get "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet152v2/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid $convert_flags "$fn_model" -y
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" data resnetv27_dense0_fwd
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
 }

 shufflenet()
@@ -201,10 +384,12 @@ shufflenet()
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir ..."
-        python convert_data_pb_0.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1
+        python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
 }

 squeezenet()
@@ -222,10 +407,59 @@ squeezenet()
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir"
-        python convert_data_pb_0.py "$pb_dir" data_0 softmaxout_1
+        python convert_data_pb.py "$pb_dir" data_0 softmaxout_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
+}
+
+squeezenet1v1()
+{
+    bn_tar="squeezenet1.1"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/$bn_tar.onnx"
+    http_get "https://s3.amazonaws.com/onnx-model-zoo/squeezenet/squeezenet1.1/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid $convert_flags "$fn_model"
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" data squeezenet0_flatten0_reshape0
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
+}
+
+ssd()
+{
+    bn_tar="ssd"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/model.onnx"
+    http_get "https://onnxzoo.blob.core.windows.net/models/opset_10/ssd/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    mkdir "$bn_tar"
+    tar xf "$fn_tar" -C "$bn_tar"/
+    python -m onnx2fluid $convert_flags "$fn_model"
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" image bboxes,labels,scores
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
 }

 tiny_yolov2()
@@ -239,14 +473,39 @@ tiny_yolov2()
     echo "extracting ..."
     tar xf "$fn_tar"

-    python -m onnx2fluid $convert_flags "$fn_model" -xy
+    python -m onnx2fluid $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir"
-        python convert_data_pb_0.py "$pb_dir" image grid
+        python convert_data_pb.py "$pb_dir" image grid
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
+}
+
+vgg16bn()
+{
+    bn_tar="vgg16-bn"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/$bn_tar.onnx"
+    http_get "https://s3.amazonaws.com/onnx-model-zoo/vgg/vgg16-bn/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid $convert_flags "$fn_model" -y
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" data vgg0_dense2_fwd
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
 }

 vgg19()
@@ -264,10 +523,35 @@ vgg19()
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir"
-        python convert_data_pb_0.py "$pb_dir" data_0 prob_1
+        python convert_data_pb.py "$pb_dir" data_0 prob_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
+}
+
+yolov3()
+{
+    bn_tar="yolov3"
+    fn_tar="$bn_tar.tar.gz"
+    fn_model="$bn_tar/model.onnx"
+    http_get "https://onnxzoo.blob.core.windows.net/models/opset_10/yolov3/$fn_tar"
+    rm -rf "$bn_tar/"
+    echo "extracting ..."
+    tar xf "$fn_tar"
+    python -m onnx2fluid $convert_flags "$fn_model" -x #
+    for pb_dir in "$bn_tar"/*/
+    do
+        echo "converting $pb_dir ..."
+        python convert_data_pb.py "$pb_dir" input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+    done
+    rm -rf "$bn_tar/"
 }

 zfnet512()
@@ -285,10 +569,12 @@ zfnet512()
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir"
-        python convert_data_pb_0.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
+        python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
         python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
         python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    rm -rf "$bn_tar/"
 }
@@ -296,11 +582,22 @@ bvlc_alexnet
 bvlc_googlenet
 bvlc_reference_caffenet
 bvlc_reference_rcnn_ilsvrc13
+densenet121
+emotion_ferplus # not supported
 inception_v1
 inception_v2
+mobilenet
+resnet18
 resnet50
+resnet100_arcface
+resnet101_duc
+resnet152
 shufflenet
 squeezenet # softmax bug
-# tiny_yolov2 # not supported
+squeezenet1v1
+ssd # version not supported
+tiny_yolov2 # not supported
+vgg16bn
 vgg19
+yolov3 # malformed model ?
 zfnet512
@@ -92,7 +92,7 @@ parser.add_argument(
 parser.add_argument(
     '--rtol',
     type=float,
-    default=1e-4,
+    default=1e-2,
     help='assertion relative tolerance for validation',
 )
 args = parser.parse_args()
...
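The relaxed default matters because validation presumably ends in an elementwise closeness check of the golden outputs against the converted model's outputs. A minimal sketch of such a check (function and argument names are illustrative, not the actual onnx2fluid.validation internals):

import numpy as np

def outputs_match(golden, inferred, rtol=1e-2, atol=1e-4):
    # elementwise test: |inferred - golden| <= atol + rtol * |golden|
    return all(
        np.allclose(g, i, rtol=rtol, atol=atol)
        for g, i in zip(golden, inferred))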
@@ -22,7 +22,6 @@ __all__ = [
     'main',
 ]

-DEFAULT_ONNX_OPSET_VERSION = 9
 DEFAULT_MODEL_MODULE = 'model'
 DEFAULT_MODEL_FUNC = 'inference'
@@ -30,6 +29,7 @@ DEFAULT_MODEL_FUNC = 'inference'
 def main(**kwargs):
     """main program entry"""

+    from .conversion import DEFAULT_ONNX_OPSET_VERSION
     from .conversion import convert

     logger = logging.getLogger('onnx2fluid')
@@ -44,9 +44,9 @@ def main(**kwargs):
                  if save_dir else basepath) + shutil.os.sep
     model_basename = DEFAULT_MODEL_MODULE + '.py'
     model_func_name = DEFAULT_MODEL_FUNC
-    onnx_opset_version = DEFAULT_ONNX_OPSET_VERSION
     onnx_opset_pedantic = kwargs.pop('pedantic', True)
-    onnx_skip_version_conversion = kwargs.pop('skip_version_conversion', False)
+    skip_version_conversion = kwargs.pop('skip_version_conversion', False)
+    onnx_opset_version = None if skip_version_conversion else DEFAULT_ONNX_OPSET_VERSION

     # convert
     convert(filename,
@@ -55,7 +55,6 @@ def main(**kwargs):
             model_func_name=model_func_name,
             onnx_opset_version=onnx_opset_version,
             onnx_opset_pedantic=onnx_opset_pedantic,
-            onnx_skip_version_conversion=onnx_skip_version_conversion,
             **kwargs)

     # validate
@@ -69,12 +68,11 @@ def main(**kwargs):
                        golden_data_filename, **kwargs)
     logger.info('starting validation on code ...')
-    passed &= validate(
-        shutil.os.path.join(save_dir, model_basename),
+    # this re-generate desc proto with Python code when debug on
+    passed &= validate(shutil.os.path.join(save_dir, model_basename),
                        golden_data_filename,
                        model_func_name=model_func_name,
-                       save_inference_model=
-                       debug,  # re-generate desc proto with python code when debug on
+                       save_inference_model=debug,
                        **kwargs)

     if not passed:
...
@@ -14,15 +14,16 @@ __all__ = [
     'convert',
 ]

+DEFAULT_ONNX_OPSET_VERSION = 9

 def convert(onnx_model_filename,
             save_dir,
             model_basename='model.py',
             model_func_name='inference',
             embed_params=False,
-            onnx_opset_version=9,
+            onnx_opset_version=None,
             onnx_opset_pedantic=True,
-            onnx_skip_version_conversion=False,
             debug=False,
             **kwargs):
     """
@@ -50,11 +51,13 @@ def convert(onnx_model_filename,
     # prepare onnx model
     logger.info('loading model: %s ...', onnx_model_filename)
     onnx_model = onnx.load(onnx_model_filename)

     try:
         logger.info('checking model ...')
         check_model(onnx_model)
-        if onnx_skip_version_conversion:  # WORKAROUND: RuntimeError: No Adapter For OP
-            logger.debug('assumed opset version: %d', onnx_opset_version)
+        if onnx_opset_version is None:  # WORKAROUND: RuntimeError: No Adapter For OP
+            logger.debug('assumed opset version: %d',
+                         DEFAULT_ONNX_OPSET_VERSION)
             logger.warning(
                 'opset conversion skipped for onnx_opset_pedantic is OFF')
         else:
@@ -68,6 +71,7 @@ def convert(onnx_model_filename,
         logger.warning('due to onnx_opset_pedantic is OFF')
         logger.warning('the ONNX model sanity checking error is suppressed')
         logger.warning('value_info inferring may be uncompleted')
+
     # onnx model optimization
     logger.info('model has %d ops', len(onnx_model.graph.node))
     logger.info('optimizing model ...')
@@ -87,10 +91,7 @@ def convert(onnx_model_filename,
         debug_model_filename, _ = shutil.os.path.splitext(onnx_model_filename)
         onnx.save(model, debug_model_filename + '.optimized_and_inffered.onnx')
-    # I/O instances
-    # onnx.save(model, '/tmp/export/optimized_and_inffered.onnx')
+    # I/O instances
     onnx_graph = onnx_model.graph
     fluid_program = Program()
     fluid_writer = Writer()
@@ -114,8 +115,8 @@ def convert(onnx_model_filename,
     # op set conversion
     # topo = 'backward' if embed_params else 'forward'
     topo = 'forward'
-    for name, domain, op_type, inputs, outputs, attrs in graph_ops(
-            onnx_graph, topo=topo):
+    for name, domain, op_type, inputs, outputs, attrs in graph_ops(onnx_graph,
+                                                                   topo=topo):
         logger.debug('translating op %s %s::%s ...', name, domain, op_type)
         if domain == DEFAULT_OP_DOMAIN:
             domain = ''
@@ -140,6 +141,24 @@ def convert(onnx_model_filename,
     logger.info('%d ops in, %d ops out', len(onnx_graph.node),
                 len(fluid_program.op_descs))

+    # shape-inference
+    for name, value_info in graph_value_infos.items():
+        var_name = make_var_name(name)
+        fluid_program.VarTypeInfo(var_name, value_info,
+                                  remove_batch=False)  # shape-infer only
+    bad_var_names = []
+    for var_name, var_desc in fluid_program.var_descs.items():
+        if not var_desc.type.lod_tensor.HasField('tensor'):
+            bad_var_names.append(var_name)
+    if len(bad_var_names) > 0:
+        logger.warning('type info not infered for var %s ...',
+                       ', '.join(bad_var_names[:5]))
+        logger.warning('this causes little problem for PaddlePaddle, '
+                       'but Paddle Mobile may not infer correctly')
+        logger.warning(
+            'please consider adding option -d to invoke PaddlePaddle shape-inference'
+        )
+
     # weight writer
     for name, weight in graph_weights(onnx_graph):
         graph_params.append(name)
@@ -173,8 +192,9 @@ def convert(onnx_model_filename,
         value_info = graph_value_infos[name]
         assert value_info['external']
         external_inputs.append(name)
-    fluid_writer.emit_inputs(
-        fluid_program, external_inputs, graph_value_infos,
+    fluid_writer.emit_inputs(fluid_program,
+                             external_inputs,
+                             graph_value_infos,
                              remove_batch=False)  # TODO:
     input_codes = fluid_program.codes
     fluid_program.codes = []
@@ -206,12 +226,13 @@ def convert(onnx_model_filename,
     fluid_writer.write_desc_file(
         desc_filename,
         op_descs=fluid_program.op_descs,
-        var_descs=fluid_program.var_descs,
+        var_descs=list(fluid_program.var_descs.values()),
     )
     logger.info('program saved to %s', desc_filename)

     logger.info('conversion finished')

 if __name__ == '__main__':
     del convert
@@ -283,8 +304,7 @@ if __name__ == '__main__':
     pedantic = args.pedantic
     skip_version_conversion = args.skip_version_conversion

-    convert(
-        model_filename,
+    convert(model_filename,
             save_dir,
             embed_params=embed_params,
             onnx_opset_pedantic=pedantic,
...
@@ -28,30 +28,66 @@ _ATTRTYPE = _descriptor.EnumDescriptor(
     filename=None,
     file=DESCRIPTOR,
     values=[
-        _descriptor.EnumValueDescriptor(
-            name='INT', index=0, number=0, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='FLOAT', index=1, number=1, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='STRING', index=2, number=2, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='INTS', index=3, number=3, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='FLOATS', index=4, number=4, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='STRINGS', index=5, number=5, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='BOOLEAN', index=6, number=6, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='BOOLEANS', index=7, number=7, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='BLOCK', index=8, number=8, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='LONG', index=9, number=9, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='BLOCKS', index=10, number=10, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='LONGS', index=11, number=11, options=None, type=None),
+        _descriptor.EnumValueDescriptor(name='INT',
+                                        index=0,
+                                        number=0,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='FLOAT',
+                                        index=1,
+                                        number=1,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='STRING',
+                                        index=2,
+                                        number=2,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='INTS',
+                                        index=3,
+                                        number=3,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='FLOATS',
+                                        index=4,
+                                        number=4,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='STRINGS',
+                                        index=5,
+                                        number=5,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='BOOLEAN',
+                                        index=6,
+                                        number=6,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='BOOLEANS',
+                                        index=7,
+                                        number=7,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='BLOCK',
+                                        index=8,
+                                        number=8,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='LONG',
+                                        index=9,
+                                        number=9,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='BLOCKS',
+                                        index=10,
+                                        number=10,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='LONGS',
+                                        index=11,
+                                        number=11,
+                                        options=None,
+                                        type=None),
     ],
     containing_type=None,
     options=None,
@@ -80,53 +116,111 @@ _VARTYPE_TYPE = _descriptor.EnumDescriptor(
     filename=None,
     file=DESCRIPTOR,
     values=[
-        _descriptor.EnumValueDescriptor(
-            name='BOOL', index=0, number=0, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='INT16', index=1, number=1, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='INT32', index=2, number=2, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='INT64', index=3, number=3, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='FP16', index=4, number=4, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='FP32', index=5, number=5, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='FP64', index=6, number=6, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='SIZE_T', index=7, number=19, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='UINT8', index=8, number=20, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='INT8', index=9, number=21, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='LOD_TENSOR', index=10, number=7, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='SELECTED_ROWS', index=11, number=8, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='FEED_MINIBATCH', index=12, number=9, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='FETCH_LIST', index=13, number=10, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='STEP_SCOPES', index=14, number=11, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='LOD_RANK_TABLE', index=15, number=12, options=None,
-            type=None),
-        _descriptor.EnumValueDescriptor(
-            name='LOD_TENSOR_ARRAY',
-            index=16,
-            number=13,
-            options=None,
-            type=None),
-        _descriptor.EnumValueDescriptor(
-            name='PLACE_LIST', index=17, number=14, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='READER', index=18, number=15, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='RAW', index=19, number=17, options=None, type=None),
-        _descriptor.EnumValueDescriptor(
-            name='TUPLE', index=20, number=18, options=None, type=None),
+        _descriptor.EnumValueDescriptor(name='BOOL',
+                                        index=0,
+                                        number=0,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='INT16',
+                                        index=1,
+                                        number=1,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='INT32',
+                                        index=2,
+                                        number=2,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='INT64',
+                                        index=3,
+                                        number=3,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='FP16',
+                                        index=4,
+                                        number=4,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='FP32',
+                                        index=5,
+                                        number=5,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='FP64',
+                                        index=6,
+                                        number=6,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='SIZE_T',
+                                        index=7,
+                                        number=19,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='UINT8',
+                                        index=8,
+                                        number=20,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='INT8',
+                                        index=9,
+                                        number=21,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='LOD_TENSOR',
+                                        index=10,
+                                        number=7,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='SELECTED_ROWS',
+                                        index=11,
+                                        number=8,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='FEED_MINIBATCH',
+                                        index=12,
+                                        number=9,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='FETCH_LIST',
+                                        index=13,
+                                        number=10,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='STEP_SCOPES',
+                                        index=14,
+                                        number=11,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='LOD_RANK_TABLE',
+                                        index=15,
+                                        number=12,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='LOD_TENSOR_ARRAY',
+                                        index=16,
+                                        number=13,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='PLACE_LIST',
+                                        index=17,
+                                        number=14,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='READER',
+                                        index=18,
+                                        number=15,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='RAW',
+                                        index=19,
+                                        number=17,
+                                        options=None,
+                                        type=None),
+        _descriptor.EnumValueDescriptor(name='TUPLE',
+                                        index=20,
+                                        number=18,
+                                        options=None,
+                                        type=None),
     ],
     containing_type=None,
     options=None,
@@ -1480,8 +1574,7 @@ DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
 Version = _reflection.GeneratedProtocolMessageType(
     'Version',
     (_message.Message, ),
-    dict(
-        DESCRIPTOR=_VERSION,
+    dict(DESCRIPTOR=_VERSION,
         __module__='framework_pb2'
         # @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
     ))
@@ -1601,8 +1694,7 @@ _sym_db.RegisterMessage(VarType.Tuple)
 VarDesc = _reflection.GeneratedProtocolMessageType(
     'VarDesc',
     (_message.Message, ),
-    dict(
-        DESCRIPTOR=_VARDESC,
+    dict(DESCRIPTOR=_VARDESC,
         __module__='framework_pb2'
         # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
     ))
...
...@@ -50,8 +50,7 @@ def print_pb_structure(message, loop_iterative=False, depth=0): ...@@ -50,8 +50,7 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'): if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'):
for field in message.DESCRIPTOR.fields: for field in message.DESCRIPTOR.fields:
print('\t' * depth + '-', field.name) print('\t' * depth + '-', field.name)
print_pb_structure( print_pb_structure(getattr(message, field.name),
getattr(message, field.name),
loop_iterative=loop_iterative, loop_iterative=loop_iterative,
depth=(depth + 1)) depth=(depth + 1))
...@@ -59,8 +58,9 @@ def print_pb_structure(message, loop_iterative=False, depth=0): ...@@ -59,8 +58,9 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
message, '__len__'): message, '__len__'):
for idx, item in enumerate(message): for idx, item in enumerate(message):
print('\t' * depth + '-', idx) print('\t' * depth + '-', idx)
print_pb_structure( print_pb_structure(item,
item, loop_iterative=loop_iterative, depth=(depth + 1)) loop_iterative=loop_iterative,
depth=(depth + 1))
def build_value_refs(nodes): def build_value_refs(nodes):
...@@ -86,8 +86,9 @@ def get_attribute_value2(attr): ...@@ -86,8 +86,9 @@ def get_attribute_value2(attr):
if attr.type == onnx.AttributeProto.TENSOR: if attr.type == onnx.AttributeProto.TENSOR:
dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type]) dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
data = attr.t.raw_data data = attr.t.raw_data
value = np.frombuffer( value = np.frombuffer(data,
data, dtype=dtype, count=(len(data) // dtype.itemsize)) dtype=dtype,
count=(len(data) // dtype.itemsize))
elif attr.type == onnx.AttributeProto.STRING: elif attr.type == onnx.AttributeProto.STRING:
value = attr.s value = attr.s
value = value.decode() if isinstance(value, bytes) else value value = value.decode() if isinstance(value, bytes) else value
...@@ -208,6 +209,9 @@ def node_iter(nodes, indices=None): ...@@ -208,6 +209,9 @@ def node_iter(nodes, indices=None):
if name == '': if name == '':
name = 'op_' + str(index) name = 'op_' + str(index)
else: # make_op_name
for s in ' \\|/:': #
name = name.replace(s, '_')
if domain == '': if domain == '':
domain = DEFAULT_OP_DOMAIN domain = DEFAULT_OP_DOMAIN
...@@ -250,25 +254,25 @@ def inferred_model_value_info(model): ...@@ -250,25 +254,25 @@ def inferred_model_value_info(model):
graph = model.graph graph = model.graph
value_info = Dict() value_info = Dict()
for item in graph.value_info: for item in graph.value_info:
value_info[item.name] = dict( value_info[item.name] = {
dtype=tensor_dtype(item), 'dtype': tensor_dtype(item),
shape=tensor_shape(item), 'shape': tensor_shape(item),
external=False, 'external': False,
) }
for item in graph.input: for item in graph.input:
assert item.name not in value_info assert item.name not in value_info
value_info[item.name] = dict( value_info[item.name] = {
dtype=tensor_dtype(item), 'dtype': tensor_dtype(item),
shape=tensor_shape(item), 'shape': tensor_shape(item),
external=True, 'external': True,
) }
for item in graph.output: for item in graph.output:
# assert item.name not in value_info, 'bypass-model not supported' # assert item.name not in value_info, 'bypass-model not supported'
value_info[item.name] = dict( value_info[item.name] = {
dtype=tensor_dtype(item), 'dtype': tensor_dtype(item),
shape=tensor_shape(item), 'shape': tensor_shape(item),
external=True, 'external': True,
) }
return value_info return value_info
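A small usage sketch: entries from graph.value_info only exist after shape inference has run, so a typical call looks like this (model path hypothetical):

import onnx
from onnx import shape_inference

model = shape_inference.infer_shapes(onnx.load('/tmp/model.onnx'))
for name, info in inferred_model_value_info(model).items():
    # 'external' marks graph inputs/outputs as opposed to internal values
    print(name, info['dtype'], info['shape'], info['external'])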
...@@ -307,7 +311,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None): ...@@ -307,7 +311,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
skip ops can be bypassed for inference skip ops can be bypassed for inference
""" """
if op_list is None: if op_list is None:
op_list = ['Dropout'] op_list = ('Dropout', 'Identity')
nodes = model.graph.node nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes) input_refs, output_refs = build_value_refs(nodes)
...@@ -325,7 +329,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None): ...@@ -325,7 +329,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
if not (op_type in op_list): if not (op_type in op_list):
continue continue
if op_type in ['Dropout']: if op_type in ('Dropout', ):
input_name = node.input[0] input_name = node.input[0]
output_name = node.output[0] output_name = node.output[0]
elif not (len(node.input) == 1 and len(node.output) == 1): elif not (len(node.input) == 1 and len(node.output) == 1):
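The bypass itself boils down to rewiring every downstream consumer of the skipped op's output to read the op's input instead; a standalone sketch of that rewiring (helper name hypothetical):

def bypass_node(nodes, node_idx):
    # assumes the node has exactly one data input and one output
    node = nodes[node_idx]
    src, dst = node.input[0], node.output[0]
    for other in nodes:
        for pos, val_name in enumerate(other.input):
            if val_name == dst:
                other.input[pos] = src  # consumer now reads the bypassed value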
...@@ -406,7 +410,7 @@ def optimize_model_strip_initializer(model, keep_input_only=True): ...@@ -406,7 +410,7 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
def optimize_model_cast(model): def optimize_model_cast(model):
""" """
strip cascade and unnecessary onnx::Cast strip cascade and unnecessary onnx::Cast-9:
""" """
nodes = model.graph.node nodes = model.graph.node
...@@ -463,13 +467,13 @@ def optimize_model_cast(model): ...@@ -463,13 +467,13 @@ def optimize_model_cast(model):
def optimize_model_slice(model): def optimize_model_slice(model):
""" """
strip cascade and unnecessary onnx::Slice strip cascade and unnecessary onnx::Slice-1:9
""" """
nodes = model.graph.node nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes) input_refs, output_refs = build_value_refs(nodes)
def _build_slice_node_chain(node_idx): def build_slice_node_chain(node_idx):
chain = [] chain = []
while True: while True:
node = nodes[node_idx] node = nodes[node_idx]
...@@ -485,7 +489,7 @@ def optimize_model_slice(model): ...@@ -485,7 +489,7 @@ def optimize_model_slice(model):
node_idx = list(input_refs[output_name])[0] node_idx = list(input_refs[output_name])[0]
# axis: (start, end) # axis: (start, end)
def _merge_slice(slice_chain): def merge_slice(slice_chain):
merged_slice = dict() merged_slice = dict()
for slice_node_idx in slice_chain: for slice_node_idx in slice_chain:
node = nodes[slice_node_idx] node = nodes[slice_node_idx]
...@@ -508,14 +512,14 @@ def optimize_model_slice(model): ...@@ -508,14 +512,14 @@ def optimize_model_slice(model):
ret_nodes = ret.graph.node ret_nodes = ret.graph.node
nodes_to_remove = [] nodes_to_remove = []
for node_idx in range(len(nodes)): for node_idx in range(len(nodes)):
slice_chain = _build_slice_node_chain(node_idx) slice_chain = build_slice_node_chain(node_idx)
if len(slice_chain) == 0: if len(slice_chain) == 0:
continue continue
merged_slice = _merge_slice(slice_chain) merged_slice = merge_slice(slice_chain)
if len(merged_slice) > 0 and len(slice_chain) == 1: # no need to merge if len(merged_slice) > 0 and len(slice_chain) == 1: # no need to merge
continue continue
attrs = dict(axes=[], starts=[], ends=[]) attrs = {'axes': [], 'starts': [], 'ends': []}
for axis, (start, end) in merged_slice.items(): for axis, (start, end) in merged_slice.items():
attrs['axes'].append(axis) attrs['axes'].append(axis)
attrs['starts'].append(start) attrs['starts'].append(start)
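For the merge step, one way to compose two cascaded slices on the same axis is to offset the inner window by the outer start and clip to the outer end; a sketch under the assumption of non-negative indices (negative indices need extra handling):

def merge_axis(outer, inner):  # each is (start, end) on one axis
    start = outer[0] + inner[0]
    end = min(outer[0] + inner[1], outer[1])
    return start, end

# x[2:10][1:5] selects elements 3..6 of x
assert merge_axis((2, 10), (1, 5)) == (3, 7)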
......
...@@ -38,6 +38,7 @@ DEFAULT_OP_MAPPING_FIELD_VALUES[ ...@@ -38,6 +38,7 @@ DEFAULT_OP_MAPPING_FIELD_VALUES[
DEFAULT_OP_MAPPING_FIELD_VALUES[ DEFAULT_OP_MAPPING_FIELD_VALUES[
'OUTPUT_PERM'] = None # sampler: [idx_onnx_arg...] 'OUTPUT_PERM'] = None # sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES['FILL_NAME_FIELD'] = True DEFAULT_OP_MAPPING_FIELD_VALUES['FILL_NAME_FIELD'] = True
DEFAULT_OP_MAPPING_VALUES = list(DEFAULT_OP_MAPPING_FIELD_VALUES.values())
DEFAULT_OP_MAPPING = { DEFAULT_OP_MAPPING = {
## nil ops ## ## nil ops ##
...@@ -145,24 +146,20 @@ DEFAULT_IOA_CONSTRAINTS = { ...@@ -145,24 +146,20 @@ DEFAULT_IOA_CONSTRAINTS = {
def _make_var_name(name): def _make_var_name(name):
""" """
make a valid variable name in Python code make a valid variable name in Python code and in filesystem
""" """
if name == '': if name == '':
return '_' return '_'
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' *?\\/-:': for s in ' \\|/:': #
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
return name return name
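The effect of _make_var_name on typical ONNX value names, for reference:

assert _make_var_name('') == '_'
assert _make_var_name('7x') == 'var_7x'
assert _make_var_name('conv1/weights:0') == 'conv1_weights_0'
assert _make_var_name('_hidden') == 'var_hidden'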
#def _value_info_or_none(value_infos, val_name):
# return value_infos.get(val_name, None)
def _dtype(value_infos, val_name): def _dtype(value_infos, val_name):
return _np.dtype(value_infos[val_name]['dtype']) return _np.dtype(value_infos[val_name]['dtype'])
...@@ -204,7 +201,7 @@ def _const_weight_or_none(value_infos, val_name): ...@@ -204,7 +201,7 @@ def _const_weight_or_none(value_infos, val_name):
def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs): def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
info = DEFAULT_OP_MAPPING[op_type] info = DEFAULT_OP_MAPPING[op_type]
info.extend(list(DEFAULT_OP_MAPPING_FIELD_VALUES.values())[len(info):]) info.extend(DEFAULT_OP_MAPPING_VALUES[len(info):])
( (
fluid_op, fluid_op,
...@@ -295,7 +292,7 @@ def _zeros_like(prog, val_ref, val_out, value_infos): ...@@ -295,7 +292,7 @@ def _zeros_like(prog, val_ref, val_out, value_infos):
'Sub', 'Sub',
[val_ref, val_ref], [val_ref, val_ref],
[val_out], # val [val_out], # val
dict(axis=0), {'axis': 0},
value_infos, value_infos,
) )
...@@ -317,11 +314,11 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE ...@@ -317,11 +314,11 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
'Pad', 'Pad',
[val_name], [val_name],
[val_padded], # val [val_padded], # val
dict( {
mode='constant', 'mode': 'constant',
value=0., 'value': 0.,
pads=pads, 'pads': pads,
), },
value_infos=value_infos, value_infos=value_infos,
name=val_padded, name=val_padded,
) )
...@@ -372,14 +369,14 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''): ...@@ -372,14 +369,14 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
fluid_op, fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'), ([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
dict( {
global_pooling=False, 'global_pooling': False,
adaptive=True, 'adaptive': True,
exclusive=True, 'exclusive': True,
require_index=has_indices, 'require_index': has_indices,
pooling_type=pool_type, 'pooling_type': pool_type,
ksize=pool_size, 'ksize': pool_size,
), },
) )
...@@ -419,12 +416,12 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -419,12 +416,12 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
fluid_op, fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y], 'Out'), ([var_y], 'Out'),
dict( {
global_pooling=True, 'global_pooling': True,
adaptive=False, 'adaptive': False,
pooling_type=pool_type, 'pooling_type': pool_type,
ksize=[-1, -1], 'ksize': [-1, -1],
), },
) )
...@@ -481,17 +478,17 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -481,17 +478,17 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
fluid_op, fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'), ([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
dict( {
global_pooling=False, 'global_pooling': False,
adaptive=False, 'adaptive': False,
exclusive=True, 'exclusive': True,
require_index=has_indices, 'require_index': has_indices,
pooling_type=pool_type, 'pooling_type': pool_type,
ksize=pool_size, 'ksize': pool_size,
strides=strides, 'strides': strides,
paddings=paddings, 'paddings': paddings,
ceil_mode=ceil_mode, 'ceil_mode': ceil_mode,
), },
) )
...@@ -506,11 +503,11 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name): ...@@ -506,11 +503,11 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
# interpretation # interpretation
spatial_scale = attrs['spatial_scale'] # required spatial_scale = attrs['spatial_scale'] # required
pooled_height, pooled_width = attrs['pooled_shape'] # required pooled_height, pooled_width = attrs['pooled_shape'] # required
od_attrs = dict( od_attrs = {
pooled_height=pooled_height, 'pooled_height': pooled_height,
pooled_width=pooled_width, 'pooled_width': pooled_width,
spatial_scale=spatial_scale, 'spatial_scale': spatial_scale,
) }
feature_attr = '' feature_attr = ''
is_max_pool = fluid_op == 'roi_pool' is_max_pool = fluid_op == 'roi_pool'
if 'sampling_ratio' in attrs: # if 'sampling_ratio' in attrs: #
...@@ -606,11 +603,11 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): ...@@ -606,11 +603,11 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
fluid_op, fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y], 'Out'), ([var_y], 'Out'),
dict( {
interp_method=mode, 'interp_method': mode,
out_h=out_shape_[0], 'out_h': out_shape_[0],
out_w=out_shape_[1], 'out_w': out_shape_[1],
), },
) )
...@@ -662,7 +659,7 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -662,7 +659,7 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_op, fluid_op,
([var_theta], 'Theta'), ([var_theta], 'Theta'),
([var_grid], 'Output'), ([var_grid], 'Output'),
dict(output_shape=size), # f**k you API {'output_shape': size}, # f**k you API
) )
...@@ -747,16 +744,17 @@ def BatchNormalization(prog, ...@@ -747,16 +744,17 @@ def BatchNormalization(prog,
prog.VarDesc(var_saved_variance) prog.VarDesc(var_saved_variance)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
([var_x, var_scale, var_b, var_mean, var_var], 'X', 'Scale', 'Bias', ([var_x, var_scale, var_b, var_mean, var_var
'Mean', 'Variance'), ], 'X', 'Scale', 'Bias', 'Mean', 'Variance'),
([var_y, var_mean, var_saved_mean, var_saved_variance, var_var], 'Y', ([var_y, var_mean, var_saved_mean, var_saved_variance, var_var
'MeanOut', 'SavedMean', 'SavedVariance', 'VarianceOut'), ], 'Y', 'MeanOut', 'SavedMean', 'SavedVariance', 'VarianceOut'),
dict( {
is_test=1, 'is_test': 1,
data_layout='NCHW', 'data_layout': 'NCHW',
use_global_stats=False, 'use_global_stats': False,
momentum=momentum, 'momentum': momentum,
epsilon=epsilon), 'epsilon': epsilon,
},
) )
...@@ -796,11 +794,12 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -796,11 +794,12 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
fluid_op, fluid_op,
([var_input], 'X'), ([var_input], 'X'),
([var_output], 'Out'), ([var_output], 'Out'),
dict( {
in_dtype=prog.Dtype(_dtype(value_infos, 'in_dtype': prog.Dtype(_dtype(value_infos,
val_input)), # holy, required val_input)), # holy, required
out_dtype=prog.Dtype(dtype), 'out_dtype': prog.Dtype(dtype),
)) },
)
def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs): def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
...@@ -834,7 +833,7 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -834,7 +833,7 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_op, fluid_op,
(var_inps, *(['X'] * len(var_inps))), (var_inps, *(['X'] * len(var_inps))),
([var_concat_result], 'Out'), ([var_concat_result], 'Out'),
dict(axis=axis), {'axis': axis},
) )
...@@ -886,11 +885,11 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -886,11 +885,11 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
fluid_op, fluid_op,
([], ), ([], ),
([var_output], 'Out'), ([var_output], 'Out'),
dict( {
shape=shape, 'shape': shape,
dtype=prog.Dtype(dtype), 'dtype': prog.Dtype(dtype),
value=value, 'value': value,
), },
) )
else: # list parameter -> const_value else: # list parameter -> const_value
prog.Code('# {} = {} # passed directly as literal'.format( prog.Code('# {} = {} # passed directly as literal'.format(
...@@ -917,7 +916,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -917,7 +916,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'this is not supported') 'this is not supported')
dtype = attrs['value'].dtype dtype = attrs['value'].dtype
attrs = attrs.copy() attrs = attrs.copy()
attrs.update(dict(shape=shape, dtype=dtype)) # pass const attrs.update({'shape': shape, 'dtype': dtype}) # pass const
prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape)) prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape))
prog.Op( prog.Op(
...@@ -1015,12 +1014,13 @@ def Conv(prog, ...@@ -1015,12 +1014,13 @@ def Conv(prog,
fluid_op, fluid_op,
([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData' ([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData'
([var_conv if has_bias else var_y], 'Output'), ([var_conv if has_bias else var_y], 'Output'),
dict( {
strides=strides, 'strides': strides,
paddings=paddings, 'paddings': paddings,
dilations=dilations, 'dilations': dilations,
groups=num_groups, 'groups': num_groups,
)) },
)
if has_bias: if has_bias:
prog.VarDesc(var_conv) prog.VarDesc(var_conv)
prog.IntermediateOp( prog.IntermediateOp(
...@@ -1028,7 +1028,7 @@ def Conv(prog, ...@@ -1028,7 +1028,7 @@ def Conv(prog,
'Add', 'Add',
[var_conv, var_b], # [var_conv, var_b], #
[val_y], [val_y],
dict(axis=1), {'axis': 1},
value_infos=value_infos, value_infos=value_infos,
name=(name + '.bias'), name=(name + '.bias'),
) )
...@@ -1125,13 +1125,14 @@ def ConvTranspose(prog, ...@@ -1125,13 +1125,14 @@ def ConvTranspose(prog,
fluid_op, fluid_op,
([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData' ([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData'
([var_conv if has_bias else var_y], 'Output'), ([var_conv if has_bias else var_y], 'Output'),
dict( {
strides=strides, 'strides': strides,
paddings=paddings, 'paddings': paddings,
dilations=dilations, 'dilations': dilations,
# output_size=output_size, # 'output_size': output_size,
groups=num_groups, 'groups': num_groups,
)) },
)
if has_bias: if has_bias:
prog.VarDesc(var_conv) prog.VarDesc(var_conv)
prog.IntermediateOp( prog.IntermediateOp(
...@@ -1139,7 +1140,7 @@ def ConvTranspose(prog, ...@@ -1139,7 +1140,7 @@ def ConvTranspose(prog,
'Add', 'Add',
[var_conv, var_b], # [var_conv, var_b], #
[val_y], [val_y],
dict(axis=1), {'axis': 1},
value_infos=value_infos, value_infos=value_infos,
name=(name + '.bias'), name=(name + '.bias'),
) )
...@@ -1184,19 +1185,19 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1184,19 +1185,19 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'MatMul', 'MatMul',
[val_a, val_b], [val_a, val_b],
[val_mm], # val [val_mm], # val
dict( {
transpose_x=trans_a, 'transpose_x': trans_a,
transpose_y=trans_b, 'transpose_y': trans_b,
alpha=alpha, 'alpha': alpha,
), },
value_infos=value_infos, value_infos=value_infos,
name=val_mm, name=val_mm,
) )
prog.op_descs[-1].attrs.extend( prog.op_descs[-1].attrs.extend(
prog.OpDescAttrs(dict( prog.OpDescAttrs({
transpose_X=trans_a, 'transpose_X': trans_a,
transpose_Y=trans_b, 'transpose_Y': trans_b,
))) # f**k you API })) # f**k you API
if beta != 0: if beta != 0:
if beta == 1.: # exactly if beta == 1.: # exactly
prog.Op( prog.Op(
...@@ -1204,7 +1205,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1204,7 +1205,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Add', 'Add',
[val_mm, val_c], [val_mm, val_c],
[val_y], # val [val_y], # val
dict(axis=1), {'axis': 1},
value_infos=value_infos, value_infos=value_infos,
name=(name + '_beta'), name=(name + '_beta'),
) )
...@@ -1226,7 +1227,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1226,7 +1227,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Constant', 'Constant',
[], [],
[val_beta], # val [val_beta], # val
dict(value=beta), {'value': beta},
value_infos=value_infos, value_infos=value_infos,
name=val_beta, name=val_beta,
) )
...@@ -1244,7 +1245,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1244,7 +1245,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Add', 'Add',
[val_mm, val_vm], [val_mm, val_vm],
[val_y], # val [val_y], # val
dict(axis=1), {'axis': 1},
name=(name + '_bias'), name=(name + '_bias'),
) )
...@@ -1261,8 +1262,13 @@ def GlobalAveragePool(prog, ...@@ -1261,8 +1262,13 @@ def GlobalAveragePool(prog,
onnx::GlobalAveragePool-1: onnx::GlobalAveragePool-1:
""" """
return _global_pool( return _global_pool(prog,
prog, 'avg', inputs, outputs, attrs, value_infos, name=name) 'avg',
inputs,
outputs,
attrs,
value_infos,
name=name)
def GlobalMaxPool(prog, def GlobalMaxPool(prog,
...@@ -1277,60 +1283,13 @@ def GlobalMaxPool(prog, ...@@ -1277,60 +1283,13 @@ def GlobalMaxPool(prog,
onnx::GlobalMaxPool-1: onnx::GlobalMaxPool-1:
""" """
return _global_pool( return _global_pool(prog,
prog, 'max', inputs, outputs, attrs, value_infos, name=name) 'max',
inputs,
outputs,
attrs,
value_infos,
name=name)
#def LRN(
# prog, inputs, outputs, attrs, value_infos, name, # name required
# *args, **kwargs):
# """
# onnx::LRN-1:
# """
#
# # I/O
# val_x, = inputs
# val_y, = outputs
# var_x = _make_var_name(val_x)
# var_y = _make_var_name(val_y)
#
# # interpretation
# fluid_op = 'lrn'
# size = attrs['size'] # required
# alpha = attrs.get('alpha', 0.0001) # optional
# beta = attrs.get('beta', 0.75) # optional
# bias = attrs.get('bias', 1.0) # optional
# name_attr = ', name={}'.format(repr(name)) if name else ''
#
# # generation
# prog.Code('{} = layers.{}({}'
# ', n={}'
# ', k={}'
# ', alpha={}'
# ', beta={}'
# '{})'
# .format(var_y,
# fluid_op,
# var_x,
# # attrs
# size,
# bias,
# alpha,
# beta,
# name_attr,
# ))
# var_mid = name + '.mid' # hidden variable
# prog.VarDesc(var_y)
# prog.VarDesc(var_mid)
# prog.OpDesc(fluid_op,
# ([var_x], 'X'),
# ([var_y, var_mid], 'Out', 'MidOut'),
# dict(n=size,
# k=bias,
# alpha=alpha,
# beta=beta,
# ),
# )
def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args, def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args,
...@@ -1375,7 +1334,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -1375,7 +1334,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW
if output_shape: if output_shape:
assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW
od_attrs = dict(pad_value=value) od_attrs = {'pad_value': value}
if assume_pad2d: if assume_pad2d:
fluid_op = 'pad2d' fluid_op = 'pad2d'
pad2d_attr = ', mode={}, data_format="NCHW"'.format(repr(mode)) pad2d_attr = ', mode={}, data_format="NCHW"'.format(repr(mode))
...@@ -1434,11 +1393,20 @@ def PRelu(prog, ...@@ -1434,11 +1393,20 @@ def PRelu(prog,
var_y = _make_var_name(val_y) var_y = _make_var_name(val_y)
# interpretation # interpretation
mode = 'channel'
slope_shape = _shape_or_none(value_infos, val_slope)
if slope_shape is not None:
if len(slope_shape) == 0:
mode = 'all'
elif len(slope_shape) >= 2:
if slope_shape[1] != _np.product(
slope_shape): # not channel broadcasting
mode = 'element'
fluid_op = 'prelu' fluid_op = 'prelu'
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params: if embed_params:
assert name != '' assert name != ''
var_slope = '{}.w_0'.format(val_slope) var_slope = name + '.w_0'
value_infos[val_slope].setdefault('embeded_as', []).append(var_slope) value_infos[val_slope].setdefault('embeded_as', []).append(var_slope)
param_attr = '' param_attr = ''
else: else:
...@@ -1446,21 +1414,23 @@ def PRelu(prog, ...@@ -1446,21 +1414,23 @@ def PRelu(prog,
param_attr = ', param_attr={}'.format(repr(var_slope)) param_attr = ', param_attr={}'.format(repr(var_slope))
# generation # generation
prog.Code('{} = layers.{}({}, mode="all"' prog.Code('{} = layers.{}({}'
', mode={}'
'{}{})'.format( '{}{})'.format(
var_y, var_y,
fluid_op, fluid_op,
var_x, var_x,
# attrs # attrs
repr(mode),
param_attr, param_attr,
name_attr, name_attr,
)) ))
prog.VarDesc(var_y) prog.VarDesc(var_y)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
([var_x], 'X'), ([var_x, var_slope], 'X', 'Alpha'),
([var_y], 'Out'), ([var_y], 'Out'),
dict(mode='all'), {'mode': mode},
) )
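The new mode heuristic can be restated standalone; a sketch (helper name hypothetical, np.prod used in place of the deprecated np.product):

import numpy as np

def prelu_mode(slope_shape):
    if slope_shape is None:
        return 'channel'  # shape unknown: keep the per-channel default
    if len(slope_shape) == 0:
        return 'all'  # scalar slope shared by all elements
    if len(slope_shape) >= 2 and slope_shape[1] != np.prod(slope_shape):
        return 'element'  # slope is not a per-channel broadcast
    return 'channel'

assert prelu_mode(()) == 'all'
assert prelu_mode((1, 64, 1, 1)) == 'channel'
assert prelu_mode((1, 64, 32, 32)) == 'element'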
...@@ -1524,7 +1494,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1524,7 +1494,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Cast', 'Cast',
[val_shape], [val_shape],
[val_shape_int32], # var [val_shape_int32], # var
dict(to=_np.dtype('int32')), # use np.dtype {'to': _np.dtype('int32')}, # use np.dtype
value_infos=value_infos, value_infos=value_infos,
name=(name + '_cast'), name=(name + '_cast'),
) )
...@@ -1549,14 +1519,14 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1549,14 +1519,14 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
fluid_op, fluid_op,
([var_data], 'X'), ([var_data], 'X'),
([var_reshaped, var_xshape], 'Out', 'XShape'), ([var_reshaped, var_xshape], 'Out', 'XShape'),
dict(shape=shape), {'shape': shape},
) )
else: else:
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
([var_data, var_shape_int32], 'X', 'Shape'), ([var_data, var_shape_int32], 'X', 'Shape'),
([var_reshaped, var_xshape], 'Out', 'XShape'), ([var_reshaped, var_xshape], 'Out', 'XShape'),
dict(shape=shape), {'shape': shape},
) )
...@@ -1659,11 +1629,11 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1659,11 +1629,11 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
fluid_op, fluid_op,
([var_data], 'Input'), ([var_data], 'Input'),
([var_output], 'Out'), ([var_output], 'Out'),
dict( {
axes=axes, 'axes': axes,
starts=starts, 'starts': starts,
ends=ends, 'ends': ends,
), },
) )
...@@ -1701,10 +1671,10 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -1701,10 +1671,10 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_op, fluid_op,
(var_input, 'X'), (var_input, 'X'),
([var_outs], *(['Out'] * len(var_outs))), ([var_outs], *(['Out'] * len(var_outs))),
dict( {
axis=axis, 'axis': axis,
sections=split, 'sections': split,
), },
) )
...@@ -1773,7 +1743,7 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -1773,7 +1743,7 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
fluid_op, fluid_op,
([var_input], 'X'), ([var_input], 'X'),
([var_output], 'Out'), ([var_output], 'Out'),
dict(expand_times=repeats), {'expand_times': repeats},
) )
...@@ -1812,7 +1782,7 @@ def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -1812,7 +1782,7 @@ def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_op, fluid_op,
([var_data], 'X'), ([var_data], 'X'),
([var_transposed, var_xshape], 'Out', 'XShape'), ([var_transposed, var_xshape], 'Out', 'XShape'),
dict(axis=perm), # f**k you API {'axis': perm}, # f**k you API
) )
...@@ -1902,8 +1872,7 @@ if __name__ == '__main__': ...@@ -1902,8 +1872,7 @@ if __name__ == '__main__':
['input'], ['input'],
['output'], ['output'],
dict(to=2), # TensorProto.UINT8 dict(to=2), # TensorProto.UINT8
dict( dict(input=dict(shape=(2, 3), dtype=np.float32),
input=dict(shape=(2, 3), dtype=np.float32),
output=dict(shape=(2, 3), dtype=np.uint8)), output=dict(shape=(2, 3), dtype=np.uint8)),
) )
logger.info('Cast program:\n%s', prog) logger.info('Cast program:\n%s', prog)
...@@ -2101,8 +2070,7 @@ if __name__ == '__main__': ...@@ -2101,8 +2070,7 @@ if __name__ == '__main__':
logger.info('Less program:\n%s', prog) logger.info('Less program:\n%s', prog)
prog = Program() prog = Program()
_default( _default(prog,
prog,
'MatMul', ['A', 'B'], ['Y'], 'MatMul', ['A', 'B'], ['Y'],
dict(), dict(),
dict(Y=dict(shape=(2, 8), dtype=np.float32)), dict(Y=dict(shape=(2, 8), dtype=np.float32)),
...@@ -2168,11 +2136,9 @@ if __name__ == '__main__': ...@@ -2168,11 +2136,9 @@ if __name__ == '__main__':
logger.info('PRelu program:\n%s', prog) logger.info('PRelu program:\n%s', prog)
prog = Program() prog = Program()
Tile( Tile(prog, ['input', 'repeats'], ['output'],
prog, ['input', 'repeats'], ['output'],
dict(), dict(),
dict( dict(repeats=dict(const_value=[1, 2]),
repeats=dict(const_value=[1, 2]),
output=dict(shape=(2, 2, 4), dtype=np.float32)), output=dict(shape=(2, 2, 4), dtype=np.float32)),
name='Tile') name='Tile')
logger.info('Tile program:\n%s', prog) logger.info('Tile program:\n%s', prog)
...@@ -12,25 +12,25 @@ import torch ...@@ -12,25 +12,25 @@ import torch
from collections import OrderedDict as Dict from collections import OrderedDict as Dict
def _ensure_list(obj): def ensure_list(obj):
if isinstance(obj, (list, set, tuple)): if isinstance(obj, (list, tuple, set)):
return list(obj) return list(obj)
return [obj] return [obj]
def _ensure_tuple(obj): def ensure_tuple(obj):
if isinstance(obj, (list, set, tuple)): if isinstance(obj, (tuple, list, set)):
return tuple(obj) return tuple(obj)
return (obj, ) return (obj, )
def _flatten_list(obj, out=None): def flatten_list(obj, out=None):
assert isinstance(obj, list) assert isinstance(obj, list)
if out is None: if out is None:
out = type(obj)() out = type(obj)()
for item in obj: for item in obj:
if isinstance(item, list): if isinstance(item, list):
_flatten_list(item, out) flatten_list(item, out)
else: else:
out.append(item) out.append(item)
return out return out
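Expected behavior of the renamed helpers, for reference:

assert ensure_list('x') == ['x']
assert ensure_tuple([1, 2]) == (1, 2)
assert flatten_list([1, [2, [3]], 4]) == [1, 2, 3, 4]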
...@@ -41,7 +41,7 @@ def export_data(state_dict, prefix=''): ...@@ -41,7 +41,7 @@ def export_data(state_dict, prefix=''):
export binary data with meta text for raw C++ inference engines export binary data with meta text for raw C++ inference engines
""" """
def _str(obj): def str_(obj):
if isinstance(obj, (tuple, list)): if isinstance(obj, (tuple, list)):
return str(obj)[1:-1].replace(' ', '') return str(obj)[1:-1].replace(' ', '')
return str(obj) return str(obj)
...@@ -52,14 +52,14 @@ def export_data(state_dict, prefix=''): ...@@ -52,14 +52,14 @@ def export_data(state_dict, prefix=''):
data = None data = None
if torch and torch.is_tensor(value): if torch and torch.is_tensor(value):
data = value.data.cpu().numpy() data = value.data.cpu().numpy()
elif np and isinstance(value, np.ndarray): elif isinstance(value, np.ndarray):
data = value data = value
if data is not None: if data is not None:
data.tofile('{}{}.bin'.format(prefix_, key)) data.tofile('{}{}.bin'.format(prefix_, key))
fp.write('{}.dtype={}\n'.format(key, _str(data.dtype.name))) fp.write('{}.dtype={}\n'.format(key, str_(data.dtype.name)))
fp.write('{}.shape={}\n'.format(key, _str(data.shape))) fp.write('{}.shape={}\n'.format(key, str_(data.shape)))
else: else:
fp.write('{}={}\n'.format(key, _str(value))) fp.write('{}={}\n'.format(key, str_(value)))
fp.close() fp.close()
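Reading an exported tensor back is a plain raw-binary load; a sketch with hypothetical names, where dtype and shape come from the meta text:

import numpy as np

# the meta file recorded e.g. conv1_weight.dtype=float32 and
# conv1_weight.shape=64,3,7,7 (hypothetical key and shape)
weight = np.fromfile('conv1_weight.bin', dtype=np.float32)
weight = weight.reshape(64, 3, 7, 7)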
...@@ -75,43 +75,42 @@ def export_onnx_with_validation(model, ...@@ -75,43 +75,42 @@ def export_onnx_with_validation(model,
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
""" """
is_list_or_tuple = lambda x: isinstance(x, (list, tuple)) is_tuple_or_list = lambda x: isinstance(x, (tuple, list))
def _tensors_to_arrays(tensors): def tensors_to_arrays(tensors):
if torch.is_tensor(tensors): if torch.is_tensor(tensors):
return tensors.data.cpu().numpy() return tensors.data.cpu().numpy()
arrays = [] arrays = []
for tensor in tensors: for tensor in tensors:
arrays.append(_tensors_to_arrays(tensor)) arrays.append(tensors_to_arrays(tensor))
return arrays return arrays
def _zip_dict(keys, values): def zip_dict(keys, values):
ret = Dict() ret = Dict()
for idx, (key, value) in enumerate(zip(keys, values)): for idx, (key, value) in enumerate(zip(keys, values)):
is_key_list = is_list_or_tuple(key) is_key_list = is_tuple_or_list(key)
is_value_list = is_list_or_tuple(value) is_value_list = is_tuple_or_list(value)
assert is_key_list == is_value_list, 'keys and values mismatch' assert is_key_list == is_value_list, 'keys and values mismatch'
if is_value_list: if is_value_list:
ret[str(idx)] = _zip_dict(key, value) ret[str(idx)] = zip_dict(key, value)
else: else:
ret[key] = value ret[key] = value
return ret return ret
torch_inputs = _ensure_tuple(inputs) # WORKAROUND: for torch.onnx torch_inputs = ensure_tuple(inputs) # WORKAROUND: for torch.onnx
outputs = torch.onnx.export( outputs = torch.onnx.export(model,
model,
torch_inputs, torch_inputs,
export_basepath + '.onnx', export_basepath + '.onnx',
input_names=_flatten_list(input_names), input_names=flatten_list(input_names),
output_names=_flatten_list(output_names), output_names=flatten_list(output_names),
*args, *args,
**kwargs) **kwargs)
if outputs is None: # WORKAROUND: for torch.onnx if outputs is None: # WORKAROUND: for torch.onnx
outputs = model(*inputs) outputs = model(*inputs)
torch_outputs = _ensure_tuple(outputs) torch_outputs = ensure_tuple(outputs)
inputs = _zip_dict(input_names, _tensors_to_arrays(torch_inputs)) inputs = zip_dict(input_names, tensors_to_arrays(torch_inputs))
outputs = _zip_dict(output_names, _tensors_to_arrays(torch_outputs)) outputs = zip_dict(output_names, tensors_to_arrays(torch_outputs))
if use_npz: if use_npz:
np.savez(export_basepath + '.npz', inputs=inputs, outputs=outputs) np.savez(export_basepath + '.npz', inputs=inputs, outputs=outputs)
else: else:
......
...@@ -9,22 +9,21 @@ Created on Fri Mar 22 12:17:19 2019 ...@@ -9,22 +9,21 @@ Created on Fri Mar 22 12:17:19 2019
import importlib, logging, os, sys import importlib, logging, os, sys
def _flatten_dict(obj, out=None): def flatten_dict(obj, out=None):
assert isinstance(obj, dict) assert isinstance(obj, dict)
if out is None: if out is None:
out = type(obj)() out = type(obj)()
for key, value in obj.items(): for key, value in obj.items():
if isinstance(value, dict): if isinstance(value, dict):
_flatten_dict(value, out) flatten_dict(value, out)
else: else:
assert key not in out assert key not in out
out[key] = value out[key] = value
return out return out
def _ensure_list(obj): def ensure_list(obj):
for cls in [list, set, tuple]: if isinstance(obj, (list, tuple, set)):
if isinstance(obj, cls):
return list(obj) return list(obj)
return [obj] return [obj]
...@@ -33,7 +32,7 @@ def validate(fluid_model_filename, ...@@ -33,7 +32,7 @@ def validate(fluid_model_filename,
golden_data_filename, golden_data_filename,
model_func_name='inference', model_func_name='inference',
atol=1e-3, atol=1e-3,
rtol=1e-4, rtol=1e-3,
save_inference_model=False, save_inference_model=False,
**kwargs): **kwargs):
""" """
...@@ -56,8 +55,8 @@ def validate(fluid_model_filename, ...@@ -56,8 +55,8 @@ def validate(fluid_model_filename,
prog, _, var_outs = fluid.io.load_inference_model(fluid_model_dir, exe) prog, _, var_outs = fluid.io.load_inference_model(fluid_model_dir, exe)
out_names = var_outs # HINT: pass var if fetch ops already created out_names = var_outs # HINT: pass var if fetch ops already created
logger.info('model load passed') logger.info('model load passed')
elif basename.endswith('.py'): # is python code elif basename.endswith('.py'): # is Python code
logger.debug('using python code file %s', basename) logger.debug('using code file %s', basename)
module_name, _ = os.path.splitext(basename) module_name, _ = os.path.splitext(basename)
sys_path = sys.path.copy() sys_path = sys.path.copy()
sys.path.append(fluid_model_dir) sys.path.append(fluid_model_dir)
...@@ -73,14 +72,15 @@ def validate(fluid_model_filename, ...@@ -73,14 +72,15 @@ def validate(fluid_model_filename,
func) func)
var_outs = func() var_outs = func()
var_outs = _ensure_list(var_outs) var_outs = ensure_list(var_outs)
out_names = [var.name for var in var_outs out_names = [var.name for var in var_outs
] # HINT: pass string to create fetch ops ] # HINT: pass string to create fetch ops
logger.info('import passed') logger.info('import passed')
prog = fluid.default_main_program() prog = fluid.default_main_program()
fluid.io.load_persistables( fluid.io.load_persistables(executor=exe,
executor=exe, dirname=fluid_model_dir, main_program=prog) dirname=fluid_model_dir,
main_program=prog)
logger.info('weight load passed') logger.info('weight load passed')
else: else:
raise ValueError('unsupported Paddle fluid model filename') raise ValueError('unsupported Paddle fluid model filename')
...@@ -95,15 +95,14 @@ def validate(fluid_model_filename, ...@@ -95,15 +95,14 @@ def validate(fluid_model_filename,
test_data = np.load(golden_data_filename, encoding='bytes').tolist() test_data = np.load(golden_data_filename, encoding='bytes').tolist()
input_data = test_data['inputs'] input_data = test_data['inputs']
output_data = test_data['outputs'] output_data = test_data['outputs']
input_data = _flatten_dict(input_data) input_data = flatten_dict(input_data)
output_data = _flatten_dict(output_data) output_data = flatten_dict(output_data)
logger.info('found %d I/O golden data, starting test ...', logger.info('found %d I/O golden data, starting test ...',
len(input_data) + len(output_data)) len(input_data) + len(output_data))
# DEBUG: reload test for python code # DEBUG: reload test for Python code
if basename.endswith('.py') and save_inference_model: if basename.endswith('.py') and save_inference_model:
fluid.io.save_inference_model( fluid.io.save_inference_model(fluid_model_dir,
fluid_model_dir,
input_data.keys(), input_data.keys(),
var_outs, var_outs,
exe, exe,
...@@ -122,8 +121,7 @@ def validate(fluid_model_filename, ...@@ -122,8 +121,7 @@ def validate(fluid_model_filename,
for (name, truth), output in zip(output_data.items(), outputs): for (name, truth), output in zip(output_data.items(), outputs):
logger.info('testing output {} ...'.format(name)) logger.info('testing output {} ...'.format(name))
try: try:
np.testing.assert_allclose( np.testing.assert_allclose(output,
output,
truth, truth,
rtol=rtol, rtol=rtol,
atol=atol, atol=atol,
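As a reminder of what the loosened tolerances mean, assert_allclose passes when |output - truth| <= atol + rtol * |truth| holds element-wise:

np.testing.assert_allclose(1.0015, 1.0, rtol=1e-3, atol=1e-3)  # 0.0015 <= 0.002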
...@@ -174,7 +172,7 @@ if __name__ == '__main__': ...@@ -174,7 +172,7 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--rtol', '--rtol',
type=float, type=float,
default=1e-4, default=1e-2,
help='assertion relative tolerance for validation', help='assertion relative tolerance for validation',
) )
args = parser.parse_args() args = parser.parse_args()
...@@ -188,8 +186,7 @@ if __name__ == '__main__': ...@@ -188,8 +186,7 @@ if __name__ == '__main__':
golden_data_filename = args.test_data golden_data_filename = args.test_data
atol, rtol = args.atol, args.rtol atol, rtol = args.atol, args.rtol
validate( validate(fluid_model_filename,
fluid_model_filename,
golden_data_filename, golden_data_filename,
atol=atol, atol=atol,
rtol=rtol, rtol=rtol,
......
...@@ -11,6 +11,8 @@ from __future__ import division ...@@ -11,6 +11,8 @@ from __future__ import division
import logging, os import logging, os
import numpy as np import numpy as np
from collections import OrderedDict as Dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from . import symbolic from . import symbolic
...@@ -30,7 +32,7 @@ __all__ = [ ...@@ -30,7 +32,7 @@ __all__ = [
] ]
def _irepr(obj, to='_'): def irepr(obj, to='_'):
"""inline repr""" """inline repr"""
s = repr(obj) s = repr(obj)
...@@ -41,12 +43,12 @@ def _irepr(obj, to='_'): ...@@ -41,12 +43,12 @@ def _irepr(obj, to='_'):
return s return s
def _flatten_list(obj, out=None): def flatten_list(obj, out=None):
if out is None: if out is None:
out = type(obj)() out = type(obj)()
for item in obj: for item in obj:
if isinstance(item, list): if isinstance(item, list):
_flatten_list(item, out) flatten_list(item, out)
else: else:
out.append(item) out.append(item)
return out return out
...@@ -59,7 +61,7 @@ def make_attr_name(name): ...@@ -59,7 +61,7 @@ def make_attr_name(name):
if name == '': if name == '':
raise ValueError('name should not be empty') raise ValueError('name should not be empty')
for s in ' *?\\/-:': # for s in ' \\|/:': #
name = name.replace(s, '_') name = name.replace(s, '_')
if not name.startswith('_'): if not name.startswith('_'):
name = '_' + name name = '_' + name
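For reference, the effect of make_attr_name on typical attribute names:

assert make_attr_name('epsilon') == '_epsilon'
assert make_attr_name('spatial scale') == '_spatial_scale'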
...@@ -130,8 +132,8 @@ class Program(object): ...@@ -130,8 +132,8 @@ class Program(object):
od_attr.type = framework_pb2.STRING od_attr.type = framework_pb2.STRING
od_attr.s = value od_attr.s = value
elif isinstance(value, list): elif isinstance(value, list):
if len(value) > 0: if len(value) > 0: # TODO: test all items
if isinstance(value, if isinstance(value[0],
bool): # bool.mro() = [bool, int, object] bool): # bool.mro() = [bool, int, object]
od_attr.type = framework_pb2.BOOLEANS od_attr.type = framework_pb2.BOOLEANS
od_attr.bools.extend(value) od_attr.bools.extend(value)
...@@ -164,14 +166,15 @@ class Program(object): ...@@ -164,14 +166,15 @@ class Program(object):
self.code_mutable = True self.code_mutable = True
self.codes = [] self.codes = []
self.op_descs = [] self.op_descs = []
self.var_descs = [] self.var_descs = Dict()
def __repr__(self): def __repr__(self):
return ('Program(code mutable: {}) with:\n' return ('Program(code mutable: {}) with:\n'
'codes: {}\n' 'codes: {}\n'
'op_descs: {}\n' 'op_descs: {}\n'
'var_descs: {}\n').format(self.code_mutable, self.codes, 'var_descs: {}\n').format(self.code_mutable, self.codes,
self.op_descs, self.var_descs) self.op_descs,
list(self.var_descs.values()))
def Code(self, code): def Code(self, code):
""" """
...@@ -182,7 +185,7 @@ class Program(object): ...@@ -182,7 +185,7 @@ class Program(object):
self.codes.append(code) self.codes.append(code)
def OpDesc(self, def OpDesc(self,
name, op_type,
input_val_keys=None, input_val_keys=None,
output_val_keys=None, output_val_keys=None,
attrs=None): attrs=None):
...@@ -191,7 +194,7 @@ class Program(object): ...@@ -191,7 +194,7 @@ class Program(object):
""" """
desc = framework_pb2.OpDesc() desc = framework_pb2.OpDesc()
desc.type = name desc.type = op_type
if input_val_keys is not None: if input_val_keys is not None:
desc.inputs.extend(self.OpDescVars(*input_val_keys)) desc.inputs.extend(self.OpDescVars(*input_val_keys))
if output_val_keys is not None: if output_val_keys is not None:
...@@ -202,7 +205,7 @@ class Program(object): ...@@ -202,7 +205,7 @@ class Program(object):
return desc return desc
def VarDesc(self, def VarDesc(self,
name, var_name,
persistable=False, persistable=False,
value_info=None, value_info=None,
remove_batch=None): remove_batch=None):
...@@ -210,24 +213,15 @@ class Program(object): ...@@ -210,24 +213,15 @@ class Program(object):
add VarDesc, add VarDesc,
""" """
assert var_name not in self.var_descs, 'var naming conflicted'
var_desc = framework_pb2.VarDesc() var_desc = framework_pb2.VarDesc()
var_desc.name = name var_desc.name = var_name
var_desc.persistable = persistable var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
self.var_descs[var_name] = var_desc
if value_info and 'dtype' in value_info: if value_info:
tensor_desc = var_desc.type.lod_tensor.tensor self.VarTypeInfo(var_name, value_info, remove_batch=remove_batch)
tensor_desc.data_type = self.Dtype(value_info['dtype']) # required
if 'shape' in value_info:
tensor_desc.dims.extend(value_info['shape'])
if len(value_info['shape']) > 0: # skip scalars
if remove_batch is None:
remove_batch = value_info.get('remove_batch',
not persistable)
if remove_batch:
tensor_desc.dims[0] = -1
self.var_descs.append(var_desc)
def Op(self, domain, op_type, *args, **kwargs): def Op(self, domain, op_type, *args, **kwargs):
""" """
...@@ -261,13 +255,40 @@ class Program(object): ...@@ -261,13 +255,40 @@ class Program(object):
else: else:
self.code_mutable = code_mutable self.code_mutable = code_mutable
def VarTypeInfo(self, var_name, value_info, remove_batch=None):
"""
set value_info for var
"""
if var_name not in self.var_descs:
return
dtype = value_info.get('dtype', None)
if dtype is None:
return
var_desc = self.var_descs[var_name]
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(dtype) # required
shape = value_info.get('shape', None)
if shape is not None:
tensor_desc.dims.extend(shape)
if len(shape) > 0: # skip scalars
if remove_batch is None:
remove_batch = value_info.get('remove_batch',
False) #not persistable)
if remove_batch:
tensor_desc.dims[0] = -1
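A minimal sketch of the new VarDesc/VarTypeInfo split (Program() takes no arguments, as the __main__ tests above show):

prog = Program()
prog.VarDesc('image',
             value_info={'dtype': 'float32', 'shape': (1, 3, 224, 224)},
             remove_batch=True)
# the stored desc now carries dims [-1, 3, 224, 224]: the leading
# (batch) dimension is marked dynamic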
class Writer(object): class Writer(object):
""" """
fluid code and desc writer fluid code and desc writer
""" """
CODE_INDENT = ' ' * 4 # CODE_INDENT = ' ' * 4
CODE_INDENT = '\t'
@staticmethod @staticmethod
def header_code(func_name, info=''): def header_code(func_name, info=''):
...@@ -275,7 +296,7 @@ class Writer(object): ...@@ -275,7 +296,7 @@ class Writer(object):
Python header codes Python header codes
""" """
codes = list() codes = []
codes.append('"""') codes.append('"""')
codes.append('This code is generated by onnx2fluid.') codes.append('This code is generated by onnx2fluid.')
codes.append('{}'.format(info)) codes.append('{}'.format(info))
...@@ -299,9 +320,8 @@ class Writer(object): ...@@ -299,9 +320,8 @@ class Writer(object):
prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type, prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type,
inputs, outputs, inputs, outputs,
_irepr(attrs, to=', '))) irepr(attrs, to=', ')))
prog.Op( prog.Op(domain,
domain,
op_type, op_type,
inputs, inputs,
outputs, outputs,
...@@ -367,10 +387,11 @@ class Writer(object): ...@@ -367,10 +387,11 @@ class Writer(object):
'feed', 'feed',
(['feed'], 'X'), (['feed'], 'X'),
([var_name], 'Out'), ([var_name], 'Out'),
dict(col=idx), {'col': idx},
) )
prog.VarDesc( prog.VarDesc(var_name,
var_name, value_info=value_info, remove_batch=remove_batch) value_info=value_info,
remove_batch=remove_batch)
@staticmethod @staticmethod
def emit_outputs(prog, names): #, value_infos def emit_outputs(prog, names): #, value_infos
...@@ -387,7 +408,7 @@ class Writer(object): ...@@ -387,7 +408,7 @@ class Writer(object):
'fetch', 'fetch',
([var_name], 'X'), ([var_name], 'X'),
(['fetch'], 'Out'), (['fetch'], 'Out'),
dict(col=idx), {'col': idx},
) )
# var is emitted over ops # var is emitted over ops
prog.Code(code) prog.Code(code)
...@@ -398,7 +419,7 @@ class Writer(object): ...@@ -398,7 +419,7 @@ class Writer(object):
flatten codes in program flatten codes in program
""" """
for code in _flatten_list(others): for code in flatten_list(others):
codes.append(Writer.CODE_INDENT * indent + code) codes.append(Writer.CODE_INDENT * indent + code)
return codes return codes
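Expected behavior of add_codes with the new tab indent, for reference:

codes = []
Writer.add_codes(codes, ['x = 1', ['y = 2']], 1)
assert codes == ['\tx = 1', '\ty = 2']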
...@@ -451,7 +472,7 @@ class Writer(object): ...@@ -451,7 +472,7 @@ class Writer(object):
Writer.add_codes(codes, body_code, 1) Writer.add_codes(codes, body_code, 1)
fp = open(filename, 'w') fp = open(filename, 'w')
for code in _flatten_list(codes): for code in flatten_list(codes):
fp.write(code) fp.write(code)
fp.write('\n') fp.write('\n')
fp.close() fp.close()
......
...@@ -54,6 +54,8 @@ zip_safe = True ...@@ -54,6 +54,8 @@ zip_safe = True
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
onnx2fluid = onnx2fluid.__main__ onnx2fluid = onnx2fluid.__main__
onnx2fluid_convert = onnx2fluid.conversion
onnx2fluid_validate = onnx2fluid.validation
# Non-.py files such as conf or data can be added to the package via the options below; they are installed into the site-packages directory alongside the package # Non-.py files such as conf or data can be added to the package via the options below; they are installed into the site-packages directory alongside the package
# Only files are supported, not directories, but globbing is allowed # Only files are supported, not directories, but globbing is allowed
......