提交 826481c4 编写于 作者: M Macrobull

add type shape inference

上级 7c3e9379
......@@ -14,14 +14,14 @@ from collections import OrderedDict as Dict
def make_var_name(name):
"""
make a valid variable name in Python code
"""
make a valid variable name in Python code
"""
if name == '':
return '_'
if name[0].isdigit():
return 'var_' + name
for s in ' \\|/:': #
for s in ' \\|/:-': #
name = name.replace(s, '_')
if name.startswith('_'):
name = 'var' + name
......
......@@ -17,14 +17,14 @@ from glob import glob
def make_var_name(name):
"""
make a valid variable name in Python code
"""
make a valid variable name in Python code
"""
if name == '':
return '_'
if name[0].isdigit():
return 'var_' + name
for s in ' \\|/:': #
for s in ' \\|/:-': #
name = name.replace(s, '_')
if name.startswith('_'):
name = 'var' + name
......
......@@ -311,7 +311,7 @@ resnet100_arcface()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid -o /tmp/export/ "$fn_model" -y
python -m onnx2fluid $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir ..."
......
......@@ -95,6 +95,14 @@ parser.add_argument(
default=1e-2,
help='assertion relative tolerance for validation',
)
parser.add_argument(
'--infer_inputs',
'-i',
nargs='?',
default=None,
const='',
help='perform type-shape inference with given input names and re-save model',
)
args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
......
......@@ -60,19 +60,28 @@ def main(**kwargs):
# validate
passed = True
golden_data_filename = kwargs.pop('test_data', '')
if golden_data_filename:
infer_inputs = kwargs.pop('infer_inputs', None)
if golden_data_filename or infer_inputs:
from .validation import validate
save_inference_model = infer_inputs is not None
inference_input_names = infer_inputs.split(
',') if infer_inputs else None
logger.info('starting validation on desc ...')
passed &= validate(shutil.os.path.join(save_dir, '__model__'),
golden_data_filename, **kwargs)
golden_data_filename=golden_data_filename,
save_inference_model=save_inference_model,
inference_input_names=inference_input_names,
**kwargs)
logger.info('starting validation on code ...')
# this re-generate desc proto with Python code when debug on
passed &= validate(shutil.os.path.join(save_dir, model_basename),
golden_data_filename,
golden_data_filename=golden_data_filename,
model_func_name=model_func_name,
save_inference_model=debug,
save_inference_model=save_inference_model,
inference_input_names=inference_input_names,
**kwargs)
if not passed:
......
......@@ -27,8 +27,8 @@ def convert(onnx_model_filename,
debug=False,
**kwargs):
"""
convert an ONNX model to Paddle fluid Python code and desc pb
"""
convert an ONNX model to Paddle fluid Python code and desc pb
"""
import onnx
......@@ -141,23 +141,22 @@ def convert(onnx_model_filename,
logger.info('%d ops in, %d ops out', len(onnx_graph.node),
len(fluid_program.op_descs))
# shape-inference
# type-shape inference
for name, value_info in graph_value_infos.items():
var_name = make_var_name(name)
fluid_program.VarTypeInfo(var_name, value_info,
remove_batch=False) # shape-infer only
fluid_program.VarTypeShapeInfo(var_name, value_info,
remove_batch=False) # shape-infer only
bad_var_names = []
for var_name, var_desc in fluid_program.var_descs.items():
if not var_desc.type.lod_tensor.HasField('tensor'):
bad_var_names.append(var_name)
if len(bad_var_names) > 0:
logger.warning('type info not infered for var %s ...',
logger.warning('type-shape not infered for var %s ...',
', '.join(bad_var_names[:5]))
logger.warning('this causes little problem for PaddlePaddle, '
'but Paddle Mobile may not infer correctly')
logger.warning(
'please consider adding option -d to invoke PaddlePaddle shape-inference'
)
logger.warning('please consider running onnx2fluid.validation with -i '
'to invoke PaddlePaddle type-shape inference')
# weight writer
for name, weight in graph_weights(onnx_graph):
......@@ -233,13 +232,9 @@ def convert(onnx_model_filename,
logger.info('conversion finished')
if __name__ == '__main__':
del convert
def main():
import argparse
from onnx2fluid.conversion import convert
parser = argparse.ArgumentParser(
description='onnx2fluid.convert',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
......@@ -310,3 +305,11 @@ if __name__ == '__main__':
onnx_opset_pedantic=pedantic,
onnx_skip_version_conversion=skip_version_conversion,
debug=debug)
if __name__ == '__main__':
del convert
from onnx2fluid.conversion import convert
main()
......@@ -44,8 +44,8 @@ DEFAULT_OP_DOMAIN = 'ai.onnx'
def print_pb_structure(message, loop_iterative=False, depth=0):
"""
print pb fields in its structure
"""
print pb fields in its structure
"""
if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'):
for field in message.DESCRIPTOR.fields:
......@@ -65,8 +65,8 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
def build_value_refs(nodes):
"""
build op reference of inputs and outputs
"""
build op reference of inputs and outputs
"""
input_refs = Dict()
output_refs = Dict()
......@@ -80,8 +80,8 @@ def build_value_refs(nodes):
def get_attribute_value2(attr):
"""
get_attribute_value enhanced
"""
get_attribute_value enhanced
"""
if attr.type == onnx.AttributeProto.TENSOR:
dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
......@@ -99,24 +99,24 @@ def get_attribute_value2(attr):
def tensor_dtype(tensor):
    """
    Get the numpy dtype of an ONNX tensor (ValueInfoProto).

    Looks up the ONNX ``elem_type`` enum of the tensor's type in the
    onnx-provided mapping table and returns the corresponding np.dtype.
    """
    # TENSOR_TYPE_TO_NP_TYPE maps onnx TensorProto.DataType -> np.dtype
    return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]
def tensor_shape(tensor):
    """
    Get the shape of an ONNX tensor (ValueInfoProto) as a list of ints.

    NOTE(review): dimensions carried as symbolic params have no
    ``dim_value`` set and will appear as 0 (proto default) — presumably
    acceptable for downstream shape inference; confirm against callers.
    """
    return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]
def node_attrs(node):
    """
    Convert the attributes of an ONNX NodeProto to a plain dict.

    Each AttributeProto is decoded with get_attribute_value2, keyed by
    its attribute name.
    """
    return {attr.name: get_attribute_value2(attr)
            for attr in node.attribute}  # dict
......@@ -124,8 +124,8 @@ def node_attrs(node):
def node_topo(nodes, topo='default'):
"""
build indices with given topology to an ONNX node graph
"""
build indices with given topology to an ONNX node graph
"""
if topo == 'default':
return list(range(len(nodes)))
......@@ -192,8 +192,8 @@ def node_topo(nodes, topo='default'):
def node_iter(nodes, indices=None):
"""
generator for ONNX node graph with given indices
"""
generator for ONNX node graph with given indices
"""
if indices is None:
indices = range(len(nodes))
......@@ -210,7 +210,7 @@ def node_iter(nodes, indices=None):
if name == '':
name = 'op_' + str(index)
else: # make_op_name
for s in ' \\|/:': #
for s in ' \\|/:-': #
name = name.replace(s, '_')
if domain == '':
domain = DEFAULT_OP_DOMAIN
......@@ -220,8 +220,8 @@ def node_iter(nodes, indices=None):
def graph_ops(graph, topo='default'):
"""
generator for ONNX node graph with given topology
"""
generator for ONNX node graph with given topology
"""
if not isinstance(graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
......@@ -232,8 +232,8 @@ def graph_ops(graph, topo='default'):
def graph_weights(graph):
"""
generator for weights of an ONNX model
"""
generator for weights of an ONNX model
"""
if not isinstance(graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
......@@ -247,8 +247,8 @@ def graph_weights(graph):
def inferred_model_value_info(model):
"""
collect value/type info for an ONNX model
"""
collect value/type info for an ONNX model
"""
model = infer_shapes(model)
graph = model.graph
......@@ -278,8 +278,8 @@ def inferred_model_value_info(model):
def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs):
"""
skip nodes between src_output_name -> dst_input_name and connect this pair
"""
skip nodes between src_output_name -> dst_input_name and connect this pair
"""
processed = 0
for next_idx in input_refs[src_output_name]:
......@@ -293,8 +293,8 @@ def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs):
def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
"""
skip nodes between dst_output_name -> src_input_name and connect this pair
"""
skip nodes between dst_output_name -> src_input_name and connect this pair
"""
processed = 0
for prev_idx in output_refs[src_input_name]:
......@@ -308,8 +308,8 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
def optimize_model_skip_op_for_inference(model, op_list=None):
"""
skip ops can be bypassed for inference
"""
skip ops can be bypassed for inference
"""
if op_list is None:
op_list = ('Dropout', 'Identity')
......@@ -369,8 +369,8 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
def optimize_model_strip_initializer(model, keep_input_only=True):
"""
strip weights for inference
"""
strip weights for inference
"""
nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes)
......@@ -410,8 +410,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
def optimize_model_cast(model):
"""
strip cascade and unecessary onnx::Cast-9:
"""
strip cascade and unecessary onnx::Cast-9:
"""
nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes)
......@@ -467,8 +467,8 @@ def optimize_model_cast(model):
def optimize_model_slice(model):
"""
strip cascade and unecessary onnx::Slice-1:9
"""
strip cascade and unecessary onnx::Slice-1:9
"""
nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes)
......
此差异已折叠。
......@@ -25,7 +25,8 @@ def ensure_tuple(obj):
def flatten_list(obj, out=None):
assert isinstance(obj, list)
assert isinstance(obj, list), 'list type required'
if out is None:
out = type(obj)()
for item in obj:
......@@ -38,11 +39,11 @@ def flatten_list(obj, out=None):
def export_data(state_dict, prefix=''):
"""
export binary data with meta text for raw C++ inference engines
"""
export binary data with meta text for raw C++ inference engines
"""
def str_(obj):
    """
    Render obj as a compact comma-separated string.

    Tuples, lists and sets are rendered element-wise with brackets and
    spaces stripped, e.g. ``(1, 2)`` -> ``'1,2'``; anything else falls
    through to plain ``str()``.

    NOTE: a 1-tuple keeps its trailing comma (``(1,)`` -> ``'1,'``),
    matching Python's repr — callers appear to rely on simple splitting.
    """
    if isinstance(obj, (tuple, list, set)):
        # str((1, 2)) == '(1, 2)': strip the delimiters, drop spaces
        return str(obj)[1:-1].replace(' ', '')
    return str(obj)
......@@ -72,8 +73,8 @@ def export_onnx_with_validation(model,
*args,
**kwargs):
"""
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
"""
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
"""
is_tuple_or_list = lambda x: isinstance(x, (tuple, list))
......
......@@ -10,14 +10,15 @@ import importlib, logging, os, sys
def flatten_dict(obj, out=None):
    """
    Flatten a nested dict into a single-level dict.

    Nested dict values are merged into the result recursively; a leaf
    key must be unique across all nesting levels or an AssertionError
    is raised.

    Args:
        obj: the (possibly nested) dict to flatten; must be a dict.
        out: optional accumulator dict; when None, a new mapping of the
            same type as obj is created (preserves e.g. OrderedDict).

    Returns:
        the flattened dict (the accumulator).
    """
    assert isinstance(obj, dict), 'dict type required'
    if out is None:
        out = type(obj)()  # preserve the caller's dict subclass
    for key, value in obj.items():
        if isinstance(value, dict):
            flatten_dict(value, out)
        else:
            assert key not in out, 'key conflicted'
            out[key] = value
    return out
......@@ -29,15 +30,16 @@ def ensure_list(obj):
def validate(fluid_model_filename,
golden_data_filename,
model_func_name='inference',
golden_data_filename='',
atol=1e-3,
rtol=1e-3,
model_func_name='inference',
save_inference_model=False,
inference_input_names=None,
**kwargs):
"""
inference the converted Paddle fluid model, validate with given golden data
"""
inference the converted Paddle fluid model, validate with given golden data
"""
import numpy as np
import paddle.fluid as fluid
......@@ -86,24 +88,50 @@ def validate(fluid_model_filename,
raise ValueError('unsupported Paddle fluid model filename')
# load data
logger.info('using golden data %s', golden_data_filename)
if golden_data_filename.endswith('.npz'):
test_data = np.load(golden_data_filename, encoding='bytes')
input_data = test_data['inputs'].tolist()
output_data = test_data['outputs'].tolist()
if golden_data_filename:
logger.info('using golden data %s', golden_data_filename)
if golden_data_filename.endswith('.npz'):
test_data = np.load(golden_data_filename, encoding='bytes')
input_data = test_data['inputs'].tolist()
output_data = test_data['outputs'].tolist()
else:
test_data = np.load(golden_data_filename, encoding='bytes').tolist()
input_data = test_data['inputs']
output_data = test_data['outputs']
input_data = flatten_dict(input_data)
output_data = flatten_dict(output_data)
input_names = input_data.keys()
logger.info('found %d I/O golden data, starting test ...',
len(input_data) + len(output_data))
else:
test_data = np.load(golden_data_filename, encoding='bytes').tolist()
input_data = test_data['inputs']
output_data = test_data['outputs']
input_data = flatten_dict(input_data)
output_data = flatten_dict(output_data)
logger.info('found %d I/O golden data, starting test ...',
len(input_data) + len(output_data))
# DEBUG: reload test for Python code
if basename.endswith('.py') and save_inference_model:
assert inference_input_names, 'input names required for type-shape inference'
input_names = inference_input_names
logger.info('using input names: %s', ', '.join(input_names))
# type-shape inference and re-save
if save_inference_model:
for block in prog.blocks:
block_desc = block.desc
for idx_op in range(block_desc.op_size()):
op_desc = block_desc.op(idx_op)
if op_desc.type() in ('feed', 'fetch'):
continue
op_desc.infer_var_type(block_desc)
op_desc.infer_shape(block_desc)
for var_name, var in block.vars.items():
var_desc = var.desc
if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR:
continue
# WORKAROUND: dirty way to give dtype to partial-infered vars
# which could not be cleared!
try:
var.to_string(True)
except ValueError:
var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32)
fluid.io.save_inference_model(fluid_model_dir,
input_data.keys(),
input_names,
var_outs,
exe,
main_program=prog,
......@@ -112,8 +140,12 @@ def validate(fluid_model_filename,
fluid.io.load_inference_model(fluid_model_dir, exe)
logger.info('model re-load passed')
if not golden_data_filename:
return True
# execute
outputs = exe.run(prog, feed=input_data, fetch_list=out_names)
outputs = exe.run(prog, feed=input_data,
fetch_list=out_names) # out_names can be vars
logger.info('execution passed')
# validate
......@@ -134,11 +166,10 @@ def validate(fluid_model_filename,
logger.info('accuracy passed')
else:
logger.info('accuracy not passed')
return passed
if __name__ == '__main__':
def main():
import argparse
parser = argparse.ArgumentParser(
......@@ -160,6 +191,7 @@ if __name__ == '__main__':
'--test_data',
'-t',
type=str,
default='',
help='I/O golden data for validation, e.g. test.npy, test.npz',
)
parser.add_argument(
......@@ -175,19 +207,36 @@ if __name__ == '__main__':
default=1e-2,
help='assertion relative tolerance for validation',
)
parser.add_argument(
'--infer_inputs',
'-i',
nargs='?',
default=None,
const='',
help=
'perform type-shape inference with given input names and re-save model',
)
args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(format=logging_format, level=logging_level)
debug = args.debug
# debug = args.debug
fluid_model_filename = args.model[0]
golden_data_filename = args.test_data
atol, rtol = args.atol, args.rtol
save_inference_model = args.infer_inputs is not None
inference_input_names = args.infer_inputs.split(
',') if args.infer_inputs else None
validate(fluid_model_filename,
golden_data_filename,
golden_data_filename=golden_data_filename,
atol=atol,
rtol=rtol,
save_inference_model=debug)
save_inference_model=save_inference_model,
inference_input_names=inference_input_names)
if __name__ == '__main__':
main()
......@@ -44,6 +44,8 @@ def irepr(obj, to='_'):
def flatten_list(obj, out=None):
assert isinstance(obj, list), 'list type required'
if out is None:
out = type(obj)()
for item in obj:
......@@ -56,12 +58,12 @@ def flatten_list(obj, out=None):
def make_attr_name(name):
"""
make a valid code name for ParamAttr
"""
make a valid code name for ParamAttr
"""
if name == '':
raise ValueError('name should not be empty')
for s in ' \\|/:': #
assert name != '', 'name should not be empty'
for s in ' \\|/:-': #
name = name.replace(s, '_')
if not name.startswith('_'):
name = '_' + name
......@@ -70,8 +72,8 @@ def make_attr_name(name):
class Program(object):
"""
fluid Python code and ProgramDesc wrapper
"""
fluid Python code and ProgramDesc wrapper
"""
DTYPE_TO_FRAMEWORK_DTYPE = {
'bool': framework_pb2.VarType.BOOL,
......@@ -88,8 +90,8 @@ class Program(object):
@staticmethod
def Dtype(dtype):
    """
    Convert a numpy-compatible dtype spec to the fluid framework dtype enum.

    Accepts anything np.dtype() accepts (np.dtype, type, or string such
    as 'float32') and maps its canonical name through the class table.
    """
    dtype = np.dtype(dtype).name  # normalize to canonical name, e.g. 'float32'
    return Program.DTYPE_TO_FRAMEWORK_DTYPE[dtype]
......@@ -97,8 +99,8 @@ class Program(object):
@staticmethod
def OpDescVars(vals, *keys):
"""
make (OpDesc.Var)s
"""
make (OpDesc.Var)s
"""
od_vars = []
for idx, key in enumerate(keys):
......@@ -112,8 +114,8 @@ class Program(object):
@staticmethod
def OpDescAttrs(attrs):
"""
make (OpDesc.Attr)s
"""
make (OpDesc.Attr)s
"""
od_attrs = []
for key, value in attrs.items():
......@@ -178,8 +180,8 @@ class Program(object):
def Code(self, code):
"""
add Python code
"""
add Python code
"""
if self.code_mutable:
self.codes.append(code)
......@@ -190,16 +192,16 @@ class Program(object):
output_val_keys=None,
attrs=None):
"""
add OpDesc
"""
add OpDesc
"""
desc = framework_pb2.OpDesc()
desc.type = op_type
if input_val_keys is not None:
if input_val_keys:
desc.inputs.extend(self.OpDescVars(*input_val_keys))
if output_val_keys is not None:
if output_val_keys:
desc.outputs.extend(self.OpDescVars(*output_val_keys))
if attrs is not None:
if attrs:
desc.attrs.extend(self.OpDescAttrs(attrs))
self.op_descs.append(desc)
return desc
......@@ -210,8 +212,8 @@ class Program(object):
value_info=None,
remove_batch=None):
"""
add VarDesc,
"""
add VarDesc,
"""
assert var_name not in self.var_descs, 'var naming conflicted'
......@@ -220,13 +222,16 @@ class Program(object):
var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
self.var_descs[var_name] = var_desc
if value_info:
self.VarTypeInfo(var_name, value_info, remove_batch=remove_batch)
self.VarTypeShapeInfo(var_name,
value_info,
remove_batch=remove_batch)
def Op(self, domain, op_type, *args, **kwargs):
"""
convert an ONNX op and add it to program
"""
convert an ONNX op and add it to program
"""
if domain != '': # TODO: symbolic file routing by domain
raise ValueError('only default domain supported')
......@@ -242,8 +247,8 @@ class Program(object):
def IntermediateOp(self, domain, op_type, *args, **kwargs):
"""
convert an intermediate ONNX op declaring in desc program only
"""
convert an intermediate ONNX op declaring in desc program only
"""
code_mutable = self.code_mutable
self.code_mutable = False
......@@ -255,10 +260,10 @@ class Program(object):
else:
self.code_mutable = code_mutable
def VarTypeInfo(self, var_name, value_info, remove_batch=None):
def VarTypeShapeInfo(self, var_name, value_info, remove_batch=None):
"""
set value_info for var
"""
set value_info for var
"""
if var_name not in self.var_descs:
return
......@@ -284,8 +289,8 @@ class Program(object):
class Writer(object):
"""
fluid code and desc writter
"""
fluid code and desc writter
"""
# CODE_INDENT = ' ' * 4
CODE_INDENT = '\t'
......@@ -293,8 +298,8 @@ class Writer(object):
@staticmethod
def header_code(func_name, info=''):
"""
Python header codes
"""
Python header codes
"""
codes = []
codes.append('"""')
......@@ -315,8 +320,8 @@ class Writer(object):
def emit_op(prog, name, domain, op_type, inputs, outputs, attrs,
value_infos, *args, **kwargs):
"""
emit an ONNX op into program
"""
emit an ONNX op into program
"""
prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type,
inputs, outputs,
......@@ -334,8 +339,8 @@ class Writer(object):
@staticmethod
def emit_param(prog, name, value_info):
"""
emit an ONNX weight into program
"""
emit an ONNX weight into program
"""
if value_info.get('embeded_as', []):
var_names = value_info['embeded_as']
......@@ -359,8 +364,8 @@ class Writer(object):
@staticmethod
def emit_inputs(prog, names, value_infos, remove_batch=None):
"""
emit ONNX inputs into program
"""
emit ONNX inputs into program
"""
for idx, name in enumerate(names):
var_name = make_var_name(name)
......@@ -396,8 +401,8 @@ class Writer(object):
@staticmethod
def emit_outputs(prog, names): #, value_infos
"""
emit ONNX outputs into program
"""
emit ONNX outputs into program
"""
code = 'return '
for idx, name in enumerate(names):
......@@ -416,8 +421,8 @@ class Writer(object):
@staticmethod
def add_codes(codes, others, indent):
"""
flatten codes in program
"""
flatten codes in program
"""
for code in flatten_list(others):
codes.append(Writer.CODE_INDENT * indent + code)
......@@ -426,11 +431,10 @@ class Writer(object):
@staticmethod
def write_weight(weight, filename):
"""
write single weight in fluid desc
"""
write single weight in fluid desc
"""
if not isinstance(weight, np.ndarray):
raise TypeError('weight is not an ndarray')
assert isinstance(weight, np.ndarray), 'weight is not an ndarray'
tensor_desc = framework_pb2.VarType.TensorDesc()
tensor_desc.data_type = Program.Dtype(weight.dtype)
......@@ -448,12 +452,11 @@ class Writer(object):
@staticmethod
def write_weights(weights, save_dir):
"""
write multiple weights in each fluid desc
"""
write multiple weights in each fluid desc
"""
for name, weight in weights.items():
if not isinstance(weights, dict):
raise TypeError('dict type weights required')
assert isinstance(weights, dict), 'dict type weights required'
var_name = make_var_name(name)
filename = os.path.join(save_dir, var_name)
......@@ -463,8 +466,8 @@ class Writer(object):
@staticmethod
def write_code_file(filename, header_code, *body_codes):
"""
write Python code to file
"""
write Python code to file
"""
codes = []
Writer.add_codes(codes, header_code, 0)
......@@ -481,8 +484,8 @@ class Writer(object):
@staticmethod
def write_desc_file(filename, op_descs, var_descs):
"""
write desc program to file
"""
write desc program to file
"""
prog_desc = framework_pb2.ProgramDesc()
block_desc = prog_desc.blocks.add()
......
......@@ -54,8 +54,8 @@ zip_safe = True
[options.entry_points]
console_scripts =
onnx2fluid = onnx2fluid.__main__
onnx2fluid_convert = onnx2fluid.conversion
onnx2fluid_validate = onnx2fluid.validation
onnx2fluid_convert = onnx2fluid.conversion:main
onnx2fluid_validate = onnx2fluid.validation:main
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 仅支持文件,不支持目录,但可以使用通配
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册