Commit fcbdcb82 authored by Macrobull

add ops and update readme

Parent 2a82fdeb
@@ -54,7 +54,7 @@ onnx2fluid sample_1.onnx -t sample_1.npz
 onnx2fluid:
 ```shell
-onnx2fluid [-dexy] [-o /path/to/export_dir/] [-z archive.zip] [-t test_data.npz] /path/to/onnx/model.onnx
+onnx2fluid [-dexy] [-o /path/to/export_dir/] [-z archive.zip] [-t test_data.npz] [-i [input_name1,input_name2]] /path/to/onnx/model.onnx
 optional arguments:
 --debug, -d                     enable debugging
@@ -65,6 +65,8 @@ optional arguments:
 --output_dir, -o                output directory
 --archive [ARCHIVE], -z [ARCHIVE]
                                 if validation passes, package outputs into the given ZIP file
+--infer_inputs, -i [input_name1,input_name2]
+                                invoke PaddlePaddle fluid type-shape inference to refine the model
 ```
 conversion tool onnx2fluid.conversion:
@@ -76,7 +78,7 @@ onnx2fluid.conversion [-dexy] [-o /path/to/export_dir/] /path/to/onnx/model.onnx
 validation tool onnx2fluid.validate:
 ```shell
-onnx2fluid.validate [-d] [-t test_data.npz] [-p 1e-3] /path/to/onnx/model.onnx
+onnx2fluid.validate [-d] [-t test_data.npz] [-i [input_name1,input_name2]] [-p 1e-3] /path/to/onnx/model.onnx
 ```
 ## Reference
@@ -19,8 +19,8 @@ PyTorch to Paddlepaddle model conversion can be easily achieved with PyTorch ONNX
 ## Environment and dependency
 * python 3.5+ (python 2 not fully supported yet)
-* onnx == 1.4.0
-* paddlepaddle == 1.3.0 (optional for validation)
+* onnx >= 1.4
+* paddlepaddle >= 1.3.0 (optional for validation)
 ## Get started
@@ -47,10 +47,12 @@ onnx2fluid sample_unet.onnx -t sample_unet.npz
 ## Usage
+**ONNX opset 9+** is mainly supported, corresponding to PyTorch **1.0/1.1 (stable opset)**; for more information see the [ONNX doc](https://github.com/onnx/onnx/blob/master/docs/Operators.md)
 onnx2fluid (all in one):
 ```shell
-onnx2fluid [-dexy] [-o /path/to/export_dir/] [-z archive.zip] [-t test_data.npz] /path/to/onnx/model.onnx
+onnx2fluid [-dexy] [-o /path/to/export_dir/] [-z archive.zip] [-t test_data.npz] [-i [input_name1,input_name2]] /path/to/onnx/model.onnx
 optional arguments:
 --debug, -d                     enable debug logging and checking
@@ -61,6 +63,8 @@ optional arguments:
 --output_dir, -o                output directory
 --archive [ARCHIVE], -z [ARCHIVE]
                                 compress outputs to ZIP file if conversion succeeded
+--infer_inputs, -i [input_name1,input_name2]
+                                invoke PaddlePaddle fluid type-shape inference
 ```
 onnx2fluid.conversion:
@@ -72,10 +76,10 @@ onnx2fluid.conversion [-dexy] [-o /path/to/export_dir/] /path/to/onnx/model.onnx
 onnx2fluid.validate:
 ```shell
-onnx2fluid.validate [-d] [-t test_data.npz] [-p 1e-3] /path/to/onnx/model.onnx
+onnx2fluid.validate [-d] [-t test_data.npz] [-i [input_name1,input_name2]] [-p 1e-3] /path/to/onnx/model.onnx
 ```
 ## Reference
-* [PaddlePaddle fluid operators](http://www.paddlepaddle.org/documentation/docs/en/1.4/api/layers.html)
-* load converted model via [load_inference_model](http://www.paddlepaddle.org/documentation/docs/en/1.4/api/io.html#permalink-1-load_inference_model)
+* [PaddlePaddle fluid operators](http://www.paddlepaddle.org/documentation/docs/en/1.5/api/layers.html)
+* load converted model via [load_inference_model](http://www.paddlepaddle.org/documentation/docs/en/1.5/api/io.html#permalink-1-load_inference_model)
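
For the second reference above, a minimal loading sketch (assuming a fluid 1.x install; the export directory and the dummy input shape are illustrative):

```python
import numpy as np
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
# returns the inference program plus the feed/fetch metadata saved with it
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    '/path/to/export_dir/', exe)
feed = {name: np.random.rand(1, 3, 224, 224).astype(np.float32)
        for name in feed_names}  # replace with real inputs for your model
outputs = exe.run(program, feed=feed, fetch_list=fetch_targets)
```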
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019

@author: Macrobull
"""

import sys

import numpy as np

from collections import OrderedDict as Dict


def _make_var_name(name):
    """
    make a valid variable name in Python code
    """

    if name == '':
        return '_'
    if name[0].isdigit():
        return 'var_' + name
    for s in ' *?\\/-:':
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
    return name


fn = sys.argv[1]  # path of the .npz archive to convert in place
input_names = sys.argv[2].split(':')
output_names = sys.argv[3].split(':')
squeeze_data = len(sys.argv) > 4  # any 4th argument enables squeezing

data = np.load(fn, encoding='bytes')
input_data = data['inputs']
output_data = data['outputs']

# drop redundant leading singleton dimensions
while squeeze_data and input_data.ndim > 4 and input_data.shape[0] == 1:
    input_data = input_data.squeeze(0)
while squeeze_data and output_data.ndim > 2 and output_data.shape[0] == 1:
    output_data = output_data.squeeze(0)

# map sanitized variable names to arrays; a single input and output is assumed
inputs = Dict(zip(map(_make_var_name, input_names), [input_data]))
outputs = Dict(zip(map(_make_var_name, output_names), [output_data]))

np.savez(fn, inputs=inputs, outputs=outputs)  # overwrite in place
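
Since `np.savez` pickles each dict into a zero-dimensional object array, reading the archive back requires unwrapping with `.item()` (and `allow_pickle=True` on NumPy >= 1.16.3); a quick sketch:

```python
import numpy as np

data = np.load('sample_1.npz', allow_pickle=True)
inputs = data['inputs'].item()   # OrderedDict: var name -> ndarray
outputs = data['outputs'].item()
for name, arr in inputs.items():
    print(name, arr.shape, arr.dtype)
```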
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019

@author: Macrobull
"""

import os
import sys

import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper

from collections import OrderedDict as Dict
from glob import glob


def _make_var_name(name):
    """
    make a valid variable name in Python code
    """

    if name == '':
        return '_'
    if name[0].isdigit():
        return 'var_' + name
    for s in ' *?\\/-:':
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
    return name


data_dir = os.path.dirname(sys.argv[1])
input_names = sys.argv[2].split(':')
output_names = sys.argv[3].split(':')
squeeze_data = len(sys.argv) > 4  # any 4th argument enables squeezing

# load inputs; sort the glob so input_N.pb pairs with the N-th input name
# (lexicographic order, sufficient for fewer than 10 tensors)
inputs = []
for fn in sorted(glob(os.path.join(data_dir, 'input_*.pb'))):
    tensor = onnx.TensorProto()
    with open(fn, 'rb') as f:
        tensor.ParseFromString(f.read())
    tensor = numpy_helper.to_array(tensor)
    # drop redundant leading singleton dimensions
    while squeeze_data and tensor.ndim > 4 and tensor.shape[0] == 1:
        tensor = tensor.squeeze(0)
    inputs.append(tensor)

# load outputs
outputs = []
for fn in sorted(glob(os.path.join(data_dir, 'output_*.pb'))):
    tensor = onnx.TensorProto()
    with open(fn, 'rb') as f:
        tensor.ParseFromString(f.read())
    tensor = numpy_helper.to_array(tensor)
    while squeeze_data and tensor.ndim > 2 and tensor.shape[0] == 1:
        tensor = tensor.squeeze(0)
    outputs.append(tensor)

inputs = Dict(zip(map(_make_var_name, input_names), inputs))
outputs = Dict(zip(map(_make_var_name, output_names), outputs))

# numpy appends .npz, so this writes <data_dir>.npz next to the directory
np.savez(data_dir, inputs=inputs, outputs=outputs)
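
To exercise this script without downloading a model zoo package, a test data set in the expected `input_*.pb`/`output_*.pb` layout can be fabricated with `onnx.numpy_helper`; a sketch with illustrative shapes and paths:

```python
import os

import numpy as np
import onnx.numpy_helper as numpy_helper

os.makedirs('test_data_set_0', exist_ok=True)
# one input tensor; real test sets carry one .pb file per graph input/output
for i, arr in enumerate([np.random.rand(1, 3, 224, 224).astype(np.float32)]):
    tensor = numpy_helper.from_array(arr)
    with open(os.path.join('test_data_set_0', 'input_%d.pb' % i), 'wb') as f:
        f.write(tensor.SerializeToString())
```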
@@ -20,34 +20,74 @@ from onnx2fluid.torch_export_helper import export_onnx_with_validation
 prefix = 'sample_'
 idx = 0

+######## example: RNN cell ########
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.gru = nn.GRUCell(6, 5)
+        self.lstm = nn.LSTMCell(5, 4)
+
+    def forward(self, x, h1, h2, c2):
+        h = self.gru(x, h1)
+        h, c = self.lstm(h, (h2, c2))
+        return h, c
+
+model = Model()
+model.eval()
+xb = torch.rand((7, 6))
+h1 = torch.zeros((7, 5))
+h2 = torch.zeros((7, 4))
+c2 = torch.zeros((7, 4))
+yp = model(xb, h1, h2, c2)
+idx += 1
+print('index: ', idx)
+export_onnx_with_validation(model, [xb, h1, h2, c2],
+                            prefix + str(idx), ['x', 'h1', 'h2', 'c2'],
+                            ['h', 'c'],
+                            verbose=True,
+                            training=False)
+
 ######## example: RNN ########
 class Model(nn.Module):
     def __init__(self):
         super(Model, self).__init__()
-        self.gru = nn.GRU(4, 5, 3)
-        self.lstm = nn.LSTM(5, 6, 2)
+        self.gru = nn.GRU(6, 5, 3)
+        self.lstm = nn.LSTM(5, 4, 2)

-    def forward(self, x):
-        y = x
-        y, h = self.gru(y)
-        y, h = self.lstm(y)
+    def forward(self, x, h1, h2, c2):
+        y, h1 = self.gru(x, h1)
+        y, (h2, c2) = self.lstm(y, (h2, c2))
         return y

 model = Model()
 model.eval()
-xb = torch.rand((2, 3, 4))
-yp = model(xb)
+xb = torch.rand((8, 1, 6))
+h1 = torch.zeros((3, 1, 5))
+h2 = torch.zeros((2, 1, 4))
+c2 = torch.zeros((2, 1, 4))
+yp = model(xb, h1, h2, c2)
 idx += 1
 print('index: ', idx)
-export_onnx_with_validation(model, [xb],
-                            prefix + str(idx), ['x'], ['y'],
+export_onnx_with_validation(model, [xb, h1, h2, c2],
+                            prefix + str(idx), ['x', 'h1', 'h2', 'c2'], ['y'],
                             verbose=True,
                             training=False)

 ######## example: random ########
+"""
+symbolic registration:
+
+def rand(g, *shapes):
+    shapes_list = list(shapes)
+    shape = _maybe_get_const(shapes_list[0], "is")
+    return g.op('RandomUniform', shape_i=shape)
+"""
+
 class Model(nn.Module):
@@ -55,8 +95,9 @@ class Model(nn.Module):
         super(Model, self).__init__()

     def forward(self, x):
-        y = torch.rand((2, 3))  # + torch.rand_like(xb)
-        y = y + torch.randn((2, 3))  # + torch.randn_like(xb)
+        y = torch.rand((2, 3))  # + torch.rand_like(x)
+        y = y + torch.randn((2, 3))  # + torch.randn_like(x)
+        y = y + x
         return y
@@ -124,6 +165,13 @@ export_onnx_with_validation(model, [xb0, xb1],
                             training=False)

 ######## example: affine_grid ########
+"""
+symbolic registration:
+
+@parse_args('v', 'is')
+def affine_grid_generator(g, theta, size):
+    return g.op('AffineGrid', theta, size_i=size)
+"""
+
 class Model(nn.Module):
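
The `symbolic registration` docstrings added above only sketch the ONNX symbolics for `rand` and `affine_grid_generator`; wiring them in was typically done by monkey-patching `torch.onnx.symbolic` in the PyTorch 1.0/1.1 era. A hedged sketch (the module layout changed in later PyTorch releases):

```python
import torch.onnx.symbolic as sym
from torch.onnx.symbolic import _maybe_get_const, parse_args


def rand(g, *shapes):
    shapes_list = list(shapes)
    shape = _maybe_get_const(shapes_list[0], "is")
    return g.op('RandomUniform', shape_i=shape)


@parse_args('v', 'is')
def affine_grid_generator(g, theta, size):
    return g.op('AffineGrid', theta, size_i=size)


# torch.onnx.export looks up aten ops by name on this module
sym.rand = rand
sym.affine_grid_generator = affine_grid_generator
```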
@@ -61,7 +61,7 @@ def main(**kwargs):
     passed = True
     golden_data_filename = kwargs.pop('test_data', '')
     infer_inputs = kwargs.pop('infer_inputs', None)
-    if golden_data_filename or infer_inputs:
+    if golden_data_filename or infer_inputs is not None:
         from .validation import validate
         save_inference_model = infer_inputs is not None
@@ -91,7 +91,7 @@ def convert(onnx_model_filename,
     # onnx model optimization
     logger.info('model has %d ops', len(onnx_model.graph.node))
     logger.info('optimizing model ...')
-    onnx_model = polish_model(onnx_model)
+    onnx_model = polish_model(onnx_model, checking=onnx_opset_pedantic)

     # prepare filesystem
     shutil.rmtree(save_dir, ignore_errors=True)
@@ -123,6 +123,7 @@ def convert(onnx_model_filename,
     for name, weight in graph_weights(onnx_graph):
         var_name = make_var_name(name)
         value_info = value_infos[var_name]
+        value_info['lod'] = [0]
         value_info['embedded_as'] = []
         value_info['get_weight'] = (lambda w: lambda: w.tolist())(
             weight)  # lazy getter
@@ -134,8 +135,8 @@ def convert(onnx_model_filename,
     for name, domain, op_type, inputs, outputs, attrs in graph_ops(onnx_graph,
                                                                    topo=topo):
         op_name = make_var_name(name)
-        inputs = [make_var_name(val) for val in inputs]
-        outputs = [make_var_name(val) for val in outputs]
+        inputs = list(map(make_var_name, inputs))
+        outputs = list(map(make_var_name, outputs))
         logger.debug('translating op %s(%s) %s::%s ...', name, op_name, domain,
                      op_type)
         if domain == DEFAULT_OP_DOMAIN:
@@ -192,13 +193,16 @@ def convert(onnx_model_filename,
                          weight.dtype, weight.size, weight.nbytes,
                          embedded_names)
             for embedded_name in embedded_names:  # multiple references
-                fluid_writer.write_weight(
-                    weight, shutil.os.path.join(save_dir, embedded_name))
+                fluid_writer.write_weight(weight,
+                                          shutil.os.path.join(
+                                              save_dir, embedded_name),
+                                          lod=value_info['lod'])
         else:
             logger.debug('saving weight %s(%s[%d], %dB) to %s ...', name,
                          weight.dtype, weight.size, weight.nbytes, var_name)
             fluid_writer.write_weight(weight,
-                                      shutil.os.path.join(save_dir, var_name))
+                                      shutil.os.path.join(save_dir, var_name),
+                                      lod=value_info['lod'])
         fluid_writer.emit_param(fluid_program, var_name, value_info)
     param_codes = fluid_program.codes
     fluid_program.codes = []
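
One hunk above keeps the `(lambda w: lambda: w.tolist())(weight)` idiom as context, which is worth a note: a bare `lambda: weight.tolist()` would close over the loop variable by reference, so every getter would return the last weight; the outer call freezes each value at creation time. A minimal demonstration:

```python
import numpy as np

weights = [np.zeros(1), np.ones(1)]

naive = [lambda: w.tolist() for w in weights]         # all share the same `w`
frozen = [(lambda w_: lambda: w_.tolist())(w) for w in weights]

print(naive[0](), naive[1]())    # [1.0] [1.0] -- both see the final value
print(frozen[0](), frozen[1]())  # [0.0] [1.0] -- bound per iteration
```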
@@ -319,17 +319,20 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
     return processed

-def polish_model(model, extras=True):
+def polish_model(model, internals=True, extras=True, checking=True):
     """
     polish_model enhanced for inference
     """

-    check_model(model)
+    if checking:
+        check_model(model)
     strip_doc_string(model)
-    passes = optimizer.get_available_passes()
-    passes = list(filter(lambda name: not name.startswith('split_'), passes))  #
-    logger.debug('builtin optimizations to perform in ONNX:\n\t%s', passes)
-    model = optimizer.optimize(model, passes=passes)
+    if internals:
+        passes = optimizer.get_available_passes()
+        passes = list(filter(lambda name: not name.startswith('split_'),
+                             passes))  #
+        logger.debug('builtin optimizations to perform in ONNX:\n\t%s', passes)
+        model = optimizer.optimize(model, passes=passes)
     if extras:
         for optimize in (
                 optimize_model_skip_op_for_inference,
@@ -339,7 +342,8 @@ def polish_model(model, extras=True):
         ):
             model = optimize(model)
     model = infer_shapes(model)
-    check_model(model)
+    if checking:
+        check_model(model)
     return model
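
A minimal usage sketch for the new keyword arguments (the file path and call site are illustrative):

```python
import onnx

model = onnx.load('/path/to/onnx/model.onnx')
# pedantic: run the ONNX checker before and after optimization
model = polish_model(model, checking=True)
# lenient, for models with exotic ops: skip checking and the builtin
# optimizer passes, but keep the inference extras and shape inference
model = polish_model(model, internals=False, checking=False)
```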
(The diff for this file is collapsed.)
@@ -159,7 +159,7 @@ def validate(fluid_model_filename,
 #        output_names = output_data.keys()
         logger.info('with %d inputs and %d outputs', len(input_data),
                     len(output_data))
-    else:
+    elif save_inference_model:
         assert inference_input_names, 'input names required for type-shape inference'
         input_names = inference_input_names
@@ -96,7 +96,7 @@ class Program(object):
         return Program.DTYPE_TO_FRAMEWORK_DTYPE[dtype]

     @staticmethod
-    def OpDescVars(vals, *keys):
+    def OpDescVars(keys, vals):
         """
         make (OpDesc.Var)s
         """
@@ -150,13 +150,11 @@ class Program(object):
             else:
                 raise ValueError('unsupported attribute {} = {}'.format(
                     key, value))
-            else:  # WORKAROUND: shape of scalars is []
-                raise ValueError('unsupported attribute {} = {}'.format(
-                    key, value))
-                # od_attr.type = framework_pb2.INTS
-                # logger.warning('using attribute %s = %s as INTS', key, value)
+            else:  # WORKAROUND: [] not inferred
+                # raise ValueError('unsupported attribute {} = {}'.format(key, value))
+                od_attr.type = framework_pb2.INTS
+                logger.warning('using attribute %s = %s as INTS', key,
+                               value)
         else:
             raise ValueError('unsupported attribute {} = {}'.format(
                 key, value))
@@ -187,8 +185,8 @@ class Program(object):
     def OpDesc(self,
                op_type,
-               input_val_keys=None,
-               output_val_keys=None,
+               input_key_vals=None,
+               output_key_vals=None,
                attrs=None):
         """
         add OpDesc
@@ -196,10 +194,10 @@ class Program(object):
         desc = framework_pb2.OpDesc()
         desc.type = op_type
-        if input_val_keys:
-            desc.inputs.extend(self.OpDescVars(*input_val_keys))
-        if output_val_keys:
-            desc.outputs.extend(self.OpDescVars(*output_val_keys))
+        if input_key_vals:
+            desc.inputs.extend(self.OpDescVars(*input_key_vals))
+        if output_key_vals:
+            desc.outputs.extend(self.OpDescVars(*output_key_vals))
         if attrs:
             desc.attrs.extend(self.OpDescAttrs(attrs))
         self.op_descs.append(desc)
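
With the reordered signature each pair now reads (parameter keys, variable names), matching the feed/fetch call sites updated in the hunks below; a hypothetical call for a single-input, single-output op:

```python
# hypothetical 'scale' op: slot 'X' reads input_var, slot 'Out' writes output_var
prog.OpDesc(
    'scale',
    (['X'], ['input_var']),     # keys first ...
    (['Out'], ['output_var']),  # ... then the variable names
    {'scale': 2.0},
)
```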
@@ -388,8 +386,8 @@ class Writer(object):
             ))
         prog.OpDesc(
             'feed',
-            (['feed'], 'X'),
-            ([name], 'Out'),
+            (['X'], ['feed']),
+            (['Out'], [name]),
             {'col': idx},
         )
         prog.VarDesc(name, value_info=value_info, remove_batch=remove_batch)
@@ -406,8 +404,8 @@ class Writer(object):
         prog.OpDesc(
             'fetch',
-            ([name], 'X'),
-            (['fetch'], 'Out'),
+            (['X'], [name]),
+            (['Out'], ['fetch']),
             {'col': idx},
         )
         # var is emitted over ops
@@ -424,12 +422,16 @@ class Writer(object):
         return codes

     @staticmethod
-    def write_weight(weight, filename):
+    def write_weight(weight, filename, lod=None):
         """
         write single weight in fluid desc
         """

         assert isinstance(weight, np.ndarray), 'weight is not an ndarray'
+        assert lod is None or isinstance(lod,
+                                         list), 'lod should be None or list'
+        lod = lod or [0]

         tensor_desc = framework_pb2.VarType.TensorDesc()
         tensor_desc.data_type = Program.Dtype(weight.dtype)
@@ -437,7 +439,7 @@ class Writer(object):
         fp = open(filename, 'wb')
         np.array([0], dtype=np.int32).tofile(fp)  # version
-        np.array([0], dtype=np.int64).tofile(fp)  # LOD level
+        np.array(lod, dtype=np.int64).tofile(fp)  # LOD level
         np.array([0], dtype=np.int32).tofile(fp)  # tensor version
         np.array([tensor_desc.ByteSize()], dtype=np.int32).tofile(fp)
         fp.write(tensor_desc.SerializeToString())
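
The writer above serializes fluid's LoDTensor file layout: an int32 version, the LoD as int64s, an int32 tensor version, the length-prefixed `TensorDesc` proto, then the raw data. A hedged reader sketch for inspecting such files (the `framework_pb2` import path and the partial dtype map are assumptions):

```python
import numpy as np

from onnx2fluid import framework_pb2  # import path assumed

# partial fluid VarType -> numpy dtype map
DTYPES = {2: np.int32, 3: np.int64, 5: np.float32, 6: np.float64}


def read_weight(filename):
    """rough inverse of Writer.write_weight for files with a trivial LoD"""
    with open(filename, 'rb') as fp:
        np.fromfile(fp, dtype=np.int32, count=1)  # weight file version
        np.fromfile(fp, dtype=np.int64, count=1)  # LoD level ([0] above)
        np.fromfile(fp, dtype=np.int32, count=1)  # tensor version
        desc_size = int(np.fromfile(fp, dtype=np.int32, count=1)[0])
        desc = framework_pb2.VarType.TensorDesc()
        desc.ParseFromString(fp.read(desc_size))
        data = np.fromfile(fp, dtype=DTYPES[desc.data_type])
    return data.reshape(desc.dims)
```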