提交 826481c4 编写于 作者: M Macrobull

add type-shape inference

上级 7c3e9379
...@@ -14,14 +14,14 @@ from collections import OrderedDict as Dict ...@@ -14,14 +14,14 @@ from collections import OrderedDict as Dict
def make_var_name(name): def make_var_name(name):
""" """
make a valid variable name in Python code make a valid variable name in Python code
""" """
if name == '': if name == '':
return '_' return '_'
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' \\|/:': # for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
......
...@@ -17,14 +17,14 @@ from glob import glob ...@@ -17,14 +17,14 @@ from glob import glob
def make_var_name(name): def make_var_name(name):
""" """
make a valid variable name in Python code make a valid variable name in Python code
""" """
if name == '': if name == '':
return '_' return '_'
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' \\|/:': # for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
......
...@@ -311,7 +311,7 @@ resnet100_arcface() ...@@ -311,7 +311,7 @@ resnet100_arcface()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid -o /tmp/export/ "$fn_model" -y python -m onnx2fluid $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar"/*/ for pb_dir in "$bn_tar"/*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
......
...@@ -95,6 +95,14 @@ parser.add_argument( ...@@ -95,6 +95,14 @@ parser.add_argument(
default=1e-2, default=1e-2,
help='assertion relative tolerance for validation', help='assertion relative tolerance for validation',
) )
parser.add_argument(
'--infer_inputs',
'-i',
nargs='?',
default=None,
const='',
help='perform type-shape inference with given input names and re-save model',
)
args = parser.parse_args() args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s' logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
......
...@@ -60,19 +60,28 @@ def main(**kwargs): ...@@ -60,19 +60,28 @@ def main(**kwargs):
# validate # validate
passed = True passed = True
golden_data_filename = kwargs.pop('test_data', '') golden_data_filename = kwargs.pop('test_data', '')
if golden_data_filename: infer_inputs = kwargs.pop('infer_inputs', None)
if golden_data_filename or infer_inputs:
from .validation import validate from .validation import validate
save_inference_model = infer_inputs is not None
inference_input_names = infer_inputs.split(
',') if infer_inputs else None
logger.info('starting validation on desc ...') logger.info('starting validation on desc ...')
passed &= validate(shutil.os.path.join(save_dir, '__model__'), passed &= validate(shutil.os.path.join(save_dir, '__model__'),
golden_data_filename, **kwargs) golden_data_filename=golden_data_filename,
save_inference_model=save_inference_model,
inference_input_names=inference_input_names,
**kwargs)
logger.info('starting validation on code ...') logger.info('starting validation on code ...')
# this re-generate desc proto with Python code when debug on # this re-generate desc proto with Python code when debug on
passed &= validate(shutil.os.path.join(save_dir, model_basename), passed &= validate(shutil.os.path.join(save_dir, model_basename),
golden_data_filename, golden_data_filename=golden_data_filename,
model_func_name=model_func_name, model_func_name=model_func_name,
save_inference_model=debug, save_inference_model=save_inference_model,
inference_input_names=inference_input_names,
**kwargs) **kwargs)
if not passed: if not passed:
......
...@@ -27,8 +27,8 @@ def convert(onnx_model_filename, ...@@ -27,8 +27,8 @@ def convert(onnx_model_filename,
debug=False, debug=False,
**kwargs): **kwargs):
""" """
convert an ONNX model to Paddle fluid Python code and desc pb convert an ONNX model to Paddle fluid Python code and desc pb
""" """
import onnx import onnx
...@@ -141,23 +141,22 @@ def convert(onnx_model_filename, ...@@ -141,23 +141,22 @@ def convert(onnx_model_filename,
logger.info('%d ops in, %d ops out', len(onnx_graph.node), logger.info('%d ops in, %d ops out', len(onnx_graph.node),
len(fluid_program.op_descs)) len(fluid_program.op_descs))
# shape-inference # type-shape inference
for name, value_info in graph_value_infos.items(): for name, value_info in graph_value_infos.items():
var_name = make_var_name(name) var_name = make_var_name(name)
fluid_program.VarTypeInfo(var_name, value_info, fluid_program.VarTypeShapeInfo(var_name, value_info,
remove_batch=False) # shape-infer only remove_batch=False) # shape-infer only
bad_var_names = [] bad_var_names = []
for var_name, var_desc in fluid_program.var_descs.items(): for var_name, var_desc in fluid_program.var_descs.items():
if not var_desc.type.lod_tensor.HasField('tensor'): if not var_desc.type.lod_tensor.HasField('tensor'):
bad_var_names.append(var_name) bad_var_names.append(var_name)
if len(bad_var_names) > 0: if len(bad_var_names) > 0:
logger.warning('type info not infered for var %s ...', logger.warning('type-shape not infered for var %s ...',
', '.join(bad_var_names[:5])) ', '.join(bad_var_names[:5]))
logger.warning('this causes little problem for PaddlePaddle, ' logger.warning('this causes little problem for PaddlePaddle, '
'but Paddle Mobile may not infer correctly') 'but Paddle Mobile may not infer correctly')
logger.warning( logger.warning('please consider running onnx2fluid.validation with -i '
'please consider adding option -d to invoke PaddlePaddle shape-inference' 'to invoke PaddlePaddle type-shape inference')
)
# weight writer # weight writer
for name, weight in graph_weights(onnx_graph): for name, weight in graph_weights(onnx_graph):
...@@ -233,13 +232,9 @@ def convert(onnx_model_filename, ...@@ -233,13 +232,9 @@ def convert(onnx_model_filename,
logger.info('conversion finished') logger.info('conversion finished')
if __name__ == '__main__': def main():
del convert
import argparse import argparse
from onnx2fluid.conversion import convert
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='onnx2fluid.convert', description='onnx2fluid.convert',
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
...@@ -310,3 +305,11 @@ if __name__ == '__main__': ...@@ -310,3 +305,11 @@ if __name__ == '__main__':
onnx_opset_pedantic=pedantic, onnx_opset_pedantic=pedantic,
onnx_skip_version_conversion=skip_version_conversion, onnx_skip_version_conversion=skip_version_conversion,
debug=debug) debug=debug)
if __name__ == '__main__':
del convert
from onnx2fluid.conversion import convert
main()
...@@ -44,8 +44,8 @@ DEFAULT_OP_DOMAIN = 'ai.onnx' ...@@ -44,8 +44,8 @@ DEFAULT_OP_DOMAIN = 'ai.onnx'
def print_pb_structure(message, loop_iterative=False, depth=0): def print_pb_structure(message, loop_iterative=False, depth=0):
""" """
print pb fields in its structure print pb fields in its structure
""" """
if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'): if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'):
for field in message.DESCRIPTOR.fields: for field in message.DESCRIPTOR.fields:
...@@ -65,8 +65,8 @@ def print_pb_structure(message, loop_iterative=False, depth=0): ...@@ -65,8 +65,8 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
def build_value_refs(nodes): def build_value_refs(nodes):
""" """
build op reference of inputs and outputs build op reference of inputs and outputs
""" """
input_refs = Dict() input_refs = Dict()
output_refs = Dict() output_refs = Dict()
...@@ -80,8 +80,8 @@ def build_value_refs(nodes): ...@@ -80,8 +80,8 @@ def build_value_refs(nodes):
def get_attribute_value2(attr): def get_attribute_value2(attr):
""" """
get_attribute_value enhanced get_attribute_value enhanced
""" """
if attr.type == onnx.AttributeProto.TENSOR: if attr.type == onnx.AttributeProto.TENSOR:
dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type]) dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
...@@ -99,24 +99,24 @@ def get_attribute_value2(attr): ...@@ -99,24 +99,24 @@ def get_attribute_value2(attr):
def tensor_dtype(tensor):
    """
    Return the numpy dtype of an ONNX tensor (ValueInfoProto).

    Looks up the protobuf element type code in ONNX's
    TENSOR_TYPE_TO_NP_TYPE mapping.
    """
    elem_type = tensor.type.tensor_type.elem_type
    return TENSOR_TYPE_TO_NP_TYPE[elem_type]
def tensor_shape(tensor):
    """
    Return the shape of an ONNX tensor (ValueInfoProto) as a list of ints.

    Symbolic (unset) dimensions come back as 0, since protobuf yields the
    default int value for an unset `dim_value`.
    """
    shape = []
    for dim in tensor.type.tensor_type.shape.dim:
        shape.append(dim.dim_value)
    return shape
def node_attrs(node):
    """
    Convert the attributes of an ONNX NodeProto to a plain dict,
    keyed by attribute name, with values decoded by get_attribute_value2.
    """
    attrs = {}
    for attr in node.attribute:
        attrs[attr.name] = get_attribute_value2(attr)
    return attrs
...@@ -124,8 +124,8 @@ def node_attrs(node): ...@@ -124,8 +124,8 @@ def node_attrs(node):
def node_topo(nodes, topo='default'): def node_topo(nodes, topo='default'):
""" """
build indices with given topology to an ONNX node graph build indices with given topology to an ONNX node graph
""" """
if topo == 'default': if topo == 'default':
return list(range(len(nodes))) return list(range(len(nodes)))
...@@ -192,8 +192,8 @@ def node_topo(nodes, topo='default'): ...@@ -192,8 +192,8 @@ def node_topo(nodes, topo='default'):
def node_iter(nodes, indices=None): def node_iter(nodes, indices=None):
""" """
generator for ONNX node graph with given indices generator for ONNX node graph with given indices
""" """
if indices is None: if indices is None:
indices = range(len(nodes)) indices = range(len(nodes))
...@@ -210,7 +210,7 @@ def node_iter(nodes, indices=None): ...@@ -210,7 +210,7 @@ def node_iter(nodes, indices=None):
if name == '': if name == '':
name = 'op_' + str(index) name = 'op_' + str(index)
else: # make_op_name else: # make_op_name
for s in ' \\|/:': # for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if domain == '': if domain == '':
domain = DEFAULT_OP_DOMAIN domain = DEFAULT_OP_DOMAIN
...@@ -220,8 +220,8 @@ def node_iter(nodes, indices=None): ...@@ -220,8 +220,8 @@ def node_iter(nodes, indices=None):
def graph_ops(graph, topo='default'): def graph_ops(graph, topo='default'):
""" """
generator for ONNX node graph with given topology generator for ONNX node graph with given topology
""" """
if not isinstance(graph, onnx.GraphProto): if not isinstance(graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance') logger.error('graph is not a GraphProto instance')
...@@ -232,8 +232,8 @@ def graph_ops(graph, topo='default'): ...@@ -232,8 +232,8 @@ def graph_ops(graph, topo='default'):
def graph_weights(graph): def graph_weights(graph):
""" """
generator for weights of an ONNX model generator for weights of an ONNX model
""" """
if not isinstance(graph, onnx.GraphProto): if not isinstance(graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance') logger.error('graph is not a GraphProto instance')
...@@ -247,8 +247,8 @@ def graph_weights(graph): ...@@ -247,8 +247,8 @@ def graph_weights(graph):
def inferred_model_value_info(model): def inferred_model_value_info(model):
""" """
collect value/type info for an ONNX model collect value/type info for an ONNX model
""" """
model = infer_shapes(model) model = infer_shapes(model)
graph = model.graph graph = model.graph
...@@ -278,8 +278,8 @@ def inferred_model_value_info(model): ...@@ -278,8 +278,8 @@ def inferred_model_value_info(model):
def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs): def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs):
""" """
skip nodes between src_output_name -> dst_input_name and connect this pair skip nodes between src_output_name -> dst_input_name and connect this pair
""" """
processed = 0 processed = 0
for next_idx in input_refs[src_output_name]: for next_idx in input_refs[src_output_name]:
...@@ -293,8 +293,8 @@ def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs): ...@@ -293,8 +293,8 @@ def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs):
def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs): def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
""" """
skip nodes between dst_output_name -> src_input_name and connect this pair skip nodes between dst_output_name -> src_input_name and connect this pair
""" """
processed = 0 processed = 0
for prev_idx in output_refs[src_input_name]: for prev_idx in output_refs[src_input_name]:
...@@ -308,8 +308,8 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs): ...@@ -308,8 +308,8 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
def optimize_model_skip_op_for_inference(model, op_list=None): def optimize_model_skip_op_for_inference(model, op_list=None):
""" """
skip ops can be bypassed for inference skip ops can be bypassed for inference
""" """
if op_list is None: if op_list is None:
op_list = ('Dropout', 'Identity') op_list = ('Dropout', 'Identity')
...@@ -369,8 +369,8 @@ def optimize_model_skip_op_for_inference(model, op_list=None): ...@@ -369,8 +369,8 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
def optimize_model_strip_initializer(model, keep_input_only=True): def optimize_model_strip_initializer(model, keep_input_only=True):
""" """
strip weights for inference strip weights for inference
""" """
nodes = model.graph.node nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes) input_refs, output_refs = build_value_refs(nodes)
...@@ -410,8 +410,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True): ...@@ -410,8 +410,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
def optimize_model_cast(model): def optimize_model_cast(model):
""" """
strip cascade and unecessary onnx::Cast-9: strip cascade and unecessary onnx::Cast-9:
""" """
nodes = model.graph.node nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes) input_refs, output_refs = build_value_refs(nodes)
...@@ -467,8 +467,8 @@ def optimize_model_cast(model): ...@@ -467,8 +467,8 @@ def optimize_model_cast(model):
def optimize_model_slice(model): def optimize_model_slice(model):
""" """
strip cascade and unecessary onnx::Slice-1:9 strip cascade and unecessary onnx::Slice-1:9
""" """
nodes = model.graph.node nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes) input_refs, output_refs = build_value_refs(nodes)
......
此差异已折叠。
...@@ -25,7 +25,8 @@ def ensure_tuple(obj): ...@@ -25,7 +25,8 @@ def ensure_tuple(obj):
def flatten_list(obj, out=None): def flatten_list(obj, out=None):
assert isinstance(obj, list) assert isinstance(obj, list), 'list type required'
if out is None: if out is None:
out = type(obj)() out = type(obj)()
for item in obj: for item in obj:
...@@ -38,11 +39,11 @@ def flatten_list(obj, out=None): ...@@ -38,11 +39,11 @@ def flatten_list(obj, out=None):
def export_data(state_dict, prefix=''): def export_data(state_dict, prefix=''):
""" """
export binary data with meta text for raw C++ inference engines export binary data with meta text for raw C++ inference engines
""" """
def str_(obj): def str_(obj):
if isinstance(obj, (tuple, list)): if isinstance(obj, (tuple, list, set)):
return str(obj)[1:-1].replace(' ', '') return str(obj)[1:-1].replace(' ', '')
return str(obj) return str(obj)
...@@ -72,8 +73,8 @@ def export_onnx_with_validation(model, ...@@ -72,8 +73,8 @@ def export_onnx_with_validation(model,
*args, *args,
**kwargs): **kwargs):
""" """
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
""" """
is_tuple_or_list = lambda x: isinstance(x, (tuple, list)) is_tuple_or_list = lambda x: isinstance(x, (tuple, list))
......
...@@ -10,14 +10,15 @@ import importlib, logging, os, sys ...@@ -10,14 +10,15 @@ import importlib, logging, os, sys
def flatten_dict(obj, out=None):
    """
    Recursively flatten a nested dict into a single-level dict.

    Leaf values are copied into *out* (created as the same dict type as
    *obj* when not given); a leaf key appearing at more than one nesting
    level is rejected with an assertion.
    """
    assert isinstance(obj, dict), 'dict type required'
    if out is None:
        out = type(obj)()
    for key in obj:
        value = obj[key]
        if not isinstance(value, dict):
            assert key not in out, 'key conflicted'
            out[key] = value
        else:
            flatten_dict(value, out)
    return out
...@@ -29,15 +30,16 @@ def ensure_list(obj): ...@@ -29,15 +30,16 @@ def ensure_list(obj):
def validate(fluid_model_filename, def validate(fluid_model_filename,
golden_data_filename, golden_data_filename='',
model_func_name='inference',
atol=1e-3, atol=1e-3,
rtol=1e-3, rtol=1e-3,
model_func_name='inference',
save_inference_model=False, save_inference_model=False,
inference_input_names=None,
**kwargs): **kwargs):
""" """
inference the converted Paddle fluid model, validate with given golden data inference the converted Paddle fluid model, validate with given golden data
""" """
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -86,24 +88,50 @@ def validate(fluid_model_filename, ...@@ -86,24 +88,50 @@ def validate(fluid_model_filename,
raise ValueError('unsupported Paddle fluid model filename') raise ValueError('unsupported Paddle fluid model filename')
# load data # load data
logger.info('using golden data %s', golden_data_filename) if golden_data_filename:
if golden_data_filename.endswith('.npz'): logger.info('using golden data %s', golden_data_filename)
test_data = np.load(golden_data_filename, encoding='bytes') if golden_data_filename.endswith('.npz'):
input_data = test_data['inputs'].tolist() test_data = np.load(golden_data_filename, encoding='bytes')
output_data = test_data['outputs'].tolist() input_data = test_data['inputs'].tolist()
output_data = test_data['outputs'].tolist()
else:
test_data = np.load(golden_data_filename, encoding='bytes').tolist()
input_data = test_data['inputs']
output_data = test_data['outputs']
input_data = flatten_dict(input_data)
output_data = flatten_dict(output_data)
input_names = input_data.keys()
logger.info('found %d I/O golden data, starting test ...',
len(input_data) + len(output_data))
else: else:
test_data = np.load(golden_data_filename, encoding='bytes').tolist() assert inference_input_names, 'input names required for type-shape inference'
input_data = test_data['inputs']
output_data = test_data['outputs'] input_names = inference_input_names
input_data = flatten_dict(input_data) logger.info('using input names: %s', ', '.join(input_names))
output_data = flatten_dict(output_data)
logger.info('found %d I/O golden data, starting test ...', # type-shape inference and re-save
len(input_data) + len(output_data)) if save_inference_model:
for block in prog.blocks:
# DEBUG: reload test for Python code block_desc = block.desc
if basename.endswith('.py') and save_inference_model: for idx_op in range(block_desc.op_size()):
op_desc = block_desc.op(idx_op)
if op_desc.type() in ('feed', 'fetch'):
continue
op_desc.infer_var_type(block_desc)
op_desc.infer_shape(block_desc)
for var_name, var in block.vars.items():
var_desc = var.desc
if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR:
continue
# WORKAROUND: dirty way to give dtype to partial-infered vars
# which could not be cleared!
try:
var.to_string(True)
except ValueError:
var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32)
fluid.io.save_inference_model(fluid_model_dir, fluid.io.save_inference_model(fluid_model_dir,
input_data.keys(), input_names,
var_outs, var_outs,
exe, exe,
main_program=prog, main_program=prog,
...@@ -112,8 +140,12 @@ def validate(fluid_model_filename, ...@@ -112,8 +140,12 @@ def validate(fluid_model_filename,
fluid.io.load_inference_model(fluid_model_dir, exe) fluid.io.load_inference_model(fluid_model_dir, exe)
logger.info('model re-load passed') logger.info('model re-load passed')
if not golden_data_filename:
return True
# execute # execute
outputs = exe.run(prog, feed=input_data, fetch_list=out_names) outputs = exe.run(prog, feed=input_data,
fetch_list=out_names) # out_names can be vars
logger.info('execution passed') logger.info('execution passed')
# validate # validate
...@@ -134,11 +166,10 @@ def validate(fluid_model_filename, ...@@ -134,11 +166,10 @@ def validate(fluid_model_filename,
logger.info('accuracy passed') logger.info('accuracy passed')
else: else:
logger.info('accuracy not passed') logger.info('accuracy not passed')
return passed return passed
if __name__ == '__main__': def main():
import argparse import argparse
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -160,6 +191,7 @@ if __name__ == '__main__': ...@@ -160,6 +191,7 @@ if __name__ == '__main__':
'--test_data', '--test_data',
'-t', '-t',
type=str, type=str,
default='',
help='I/O golden data for validation, e.g. test.npy, test.npz', help='I/O golden data for validation, e.g. test.npy, test.npz',
) )
parser.add_argument( parser.add_argument(
...@@ -175,19 +207,36 @@ if __name__ == '__main__': ...@@ -175,19 +207,36 @@ if __name__ == '__main__':
default=1e-2, default=1e-2,
help='assertion relative tolerance for validation', help='assertion relative tolerance for validation',
) )
parser.add_argument(
'--infer_inputs',
'-i',
nargs='?',
default=None,
const='',
help=
'perform type-shape inference with given input names and re-save model',
)
args = parser.parse_args() args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s' logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level = logging.DEBUG if args.debug else logging.INFO logging_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(format=logging_format, level=logging_level) logging.basicConfig(format=logging_format, level=logging_level)
debug = args.debug # debug = args.debug
fluid_model_filename = args.model[0] fluid_model_filename = args.model[0]
golden_data_filename = args.test_data golden_data_filename = args.test_data
atol, rtol = args.atol, args.rtol atol, rtol = args.atol, args.rtol
save_inference_model = args.infer_inputs is not None
inference_input_names = args.infer_inputs.split(
',') if args.infer_inputs else None
validate(fluid_model_filename, validate(fluid_model_filename,
golden_data_filename, golden_data_filename=golden_data_filename,
atol=atol, atol=atol,
rtol=rtol, rtol=rtol,
save_inference_model=debug) save_inference_model=save_inference_model,
inference_input_names=inference_input_names)
if __name__ == '__main__':
main()
...@@ -44,6 +44,8 @@ def irepr(obj, to='_'): ...@@ -44,6 +44,8 @@ def irepr(obj, to='_'):
def flatten_list(obj, out=None): def flatten_list(obj, out=None):
assert isinstance(obj, list), 'list type required'
if out is None: if out is None:
out = type(obj)() out = type(obj)()
for item in obj: for item in obj:
...@@ -56,12 +58,12 @@ def flatten_list(obj, out=None): ...@@ -56,12 +58,12 @@ def flatten_list(obj, out=None):
def make_attr_name(name): def make_attr_name(name):
""" """
make a valid code name for ParamAttr make a valid code name for ParamAttr
""" """
if name == '': assert name != '', 'name should not be empty'
raise ValueError('name should not be empty')
for s in ' \\|/:': # for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if not name.startswith('_'): if not name.startswith('_'):
name = '_' + name name = '_' + name
...@@ -70,8 +72,8 @@ def make_attr_name(name): ...@@ -70,8 +72,8 @@ def make_attr_name(name):
class Program(object): class Program(object):
""" """
fluid Python code and ProgramDesc wrapper fluid Python code and ProgramDesc wrapper
""" """
DTYPE_TO_FRAMEWORK_DTYPE = { DTYPE_TO_FRAMEWORK_DTYPE = {
'bool': framework_pb2.VarType.BOOL, 'bool': framework_pb2.VarType.BOOL,
...@@ -88,8 +90,8 @@ class Program(object): ...@@ -88,8 +90,8 @@ class Program(object):
@staticmethod @staticmethod
def Dtype(dtype): def Dtype(dtype):
""" """
convert dtype to fulid framework dtype convert dtype to fulid framework dtype
""" """
dtype = np.dtype(dtype).name dtype = np.dtype(dtype).name
return Program.DTYPE_TO_FRAMEWORK_DTYPE[dtype] return Program.DTYPE_TO_FRAMEWORK_DTYPE[dtype]
...@@ -97,8 +99,8 @@ class Program(object): ...@@ -97,8 +99,8 @@ class Program(object):
@staticmethod @staticmethod
def OpDescVars(vals, *keys): def OpDescVars(vals, *keys):
""" """
make (OpDesc.Var)s make (OpDesc.Var)s
""" """
od_vars = [] od_vars = []
for idx, key in enumerate(keys): for idx, key in enumerate(keys):
...@@ -112,8 +114,8 @@ class Program(object): ...@@ -112,8 +114,8 @@ class Program(object):
@staticmethod @staticmethod
def OpDescAttrs(attrs): def OpDescAttrs(attrs):
""" """
make (OpDesc.Attr)s make (OpDesc.Attr)s
""" """
od_attrs = [] od_attrs = []
for key, value in attrs.items(): for key, value in attrs.items():
...@@ -178,8 +180,8 @@ class Program(object): ...@@ -178,8 +180,8 @@ class Program(object):
def Code(self, code): def Code(self, code):
""" """
add Python code add Python code
""" """
if self.code_mutable: if self.code_mutable:
self.codes.append(code) self.codes.append(code)
...@@ -190,16 +192,16 @@ class Program(object): ...@@ -190,16 +192,16 @@ class Program(object):
output_val_keys=None, output_val_keys=None,
attrs=None): attrs=None):
""" """
add OpDesc add OpDesc
""" """
desc = framework_pb2.OpDesc() desc = framework_pb2.OpDesc()
desc.type = op_type desc.type = op_type
if input_val_keys is not None: if input_val_keys:
desc.inputs.extend(self.OpDescVars(*input_val_keys)) desc.inputs.extend(self.OpDescVars(*input_val_keys))
if output_val_keys is not None: if output_val_keys:
desc.outputs.extend(self.OpDescVars(*output_val_keys)) desc.outputs.extend(self.OpDescVars(*output_val_keys))
if attrs is not None: if attrs:
desc.attrs.extend(self.OpDescAttrs(attrs)) desc.attrs.extend(self.OpDescAttrs(attrs))
self.op_descs.append(desc) self.op_descs.append(desc)
return desc return desc
...@@ -210,8 +212,8 @@ class Program(object): ...@@ -210,8 +212,8 @@ class Program(object):
value_info=None, value_info=None,
remove_batch=None): remove_batch=None):
""" """
add VarDesc, add VarDesc,
""" """
assert var_name not in self.var_descs, 'var naming conflicted' assert var_name not in self.var_descs, 'var naming conflicted'
...@@ -220,13 +222,16 @@ class Program(object): ...@@ -220,13 +222,16 @@ class Program(object):
var_desc.persistable = persistable var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
self.var_descs[var_name] = var_desc self.var_descs[var_name] = var_desc
if value_info: if value_info:
self.VarTypeInfo(var_name, value_info, remove_batch=remove_batch) self.VarTypeShapeInfo(var_name,
value_info,
remove_batch=remove_batch)
def Op(self, domain, op_type, *args, **kwargs): def Op(self, domain, op_type, *args, **kwargs):
""" """
convert an ONNX op and add it to program convert an ONNX op and add it to program
""" """
if domain != '': # TODO: symbolic file routing by domain if domain != '': # TODO: symbolic file routing by domain
raise ValueError('only default domain supported') raise ValueError('only default domain supported')
...@@ -242,8 +247,8 @@ class Program(object): ...@@ -242,8 +247,8 @@ class Program(object):
def IntermediateOp(self, domain, op_type, *args, **kwargs): def IntermediateOp(self, domain, op_type, *args, **kwargs):
""" """
convert an intermediate ONNX op declaring in desc program only convert an intermediate ONNX op declaring in desc program only
""" """
code_mutable = self.code_mutable code_mutable = self.code_mutable
self.code_mutable = False self.code_mutable = False
...@@ -255,10 +260,10 @@ class Program(object): ...@@ -255,10 +260,10 @@ class Program(object):
else: else:
self.code_mutable = code_mutable self.code_mutable = code_mutable
def VarTypeInfo(self, var_name, value_info, remove_batch=None): def VarTypeShapeInfo(self, var_name, value_info, remove_batch=None):
"""
set value_info for var
""" """
set value_info for var
"""
if var_name not in self.var_descs: if var_name not in self.var_descs:
return return
...@@ -284,8 +289,8 @@ class Program(object): ...@@ -284,8 +289,8 @@ class Program(object):
class Writer(object): class Writer(object):
""" """
fluid code and desc writter fluid code and desc writter
""" """
# CODE_INDENT = ' ' * 4 # CODE_INDENT = ' ' * 4
CODE_INDENT = '\t' CODE_INDENT = '\t'
...@@ -293,8 +298,8 @@ class Writer(object): ...@@ -293,8 +298,8 @@ class Writer(object):
@staticmethod @staticmethod
def header_code(func_name, info=''): def header_code(func_name, info=''):
""" """
Python header codes Python header codes
""" """
codes = [] codes = []
codes.append('"""') codes.append('"""')
...@@ -315,8 +320,8 @@ class Writer(object): ...@@ -315,8 +320,8 @@ class Writer(object):
def emit_op(prog, name, domain, op_type, inputs, outputs, attrs, def emit_op(prog, name, domain, op_type, inputs, outputs, attrs,
value_infos, *args, **kwargs): value_infos, *args, **kwargs):
""" """
emit an ONNX op into program emit an ONNX op into program
""" """
prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type, prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type,
inputs, outputs, inputs, outputs,
...@@ -334,8 +339,8 @@ class Writer(object): ...@@ -334,8 +339,8 @@ class Writer(object):
@staticmethod @staticmethod
def emit_param(prog, name, value_info): def emit_param(prog, name, value_info):
""" """
emit an ONNX weight into program emit an ONNX weight into program
""" """
if value_info.get('embeded_as', []): if value_info.get('embeded_as', []):
var_names = value_info['embeded_as'] var_names = value_info['embeded_as']
...@@ -359,8 +364,8 @@ class Writer(object): ...@@ -359,8 +364,8 @@ class Writer(object):
@staticmethod @staticmethod
def emit_inputs(prog, names, value_infos, remove_batch=None): def emit_inputs(prog, names, value_infos, remove_batch=None):
""" """
emit ONNX inputs into program emit ONNX inputs into program
""" """
for idx, name in enumerate(names): for idx, name in enumerate(names):
var_name = make_var_name(name) var_name = make_var_name(name)
...@@ -396,8 +401,8 @@ class Writer(object): ...@@ -396,8 +401,8 @@ class Writer(object):
@staticmethod @staticmethod
def emit_outputs(prog, names): #, value_infos def emit_outputs(prog, names): #, value_infos
""" """
emit ONNX outputs into program emit ONNX outputs into program
""" """
code = 'return ' code = 'return '
for idx, name in enumerate(names): for idx, name in enumerate(names):
...@@ -416,8 +421,8 @@ class Writer(object): ...@@ -416,8 +421,8 @@ class Writer(object):
@staticmethod @staticmethod
def add_codes(codes, others, indent): def add_codes(codes, others, indent):
""" """
flatten codes in program flatten codes in program
""" """
for code in flatten_list(others): for code in flatten_list(others):
codes.append(Writer.CODE_INDENT * indent + code) codes.append(Writer.CODE_INDENT * indent + code)
...@@ -426,11 +431,10 @@ class Writer(object): ...@@ -426,11 +431,10 @@ class Writer(object):
@staticmethod @staticmethod
def write_weight(weight, filename): def write_weight(weight, filename):
""" """
write single weight in fluid desc write single weight in fluid desc
""" """
if not isinstance(weight, np.ndarray): assert isinstance(weight, np.ndarray), 'weight is not an ndarray'
raise TypeError('weight is not an ndarray')
tensor_desc = framework_pb2.VarType.TensorDesc() tensor_desc = framework_pb2.VarType.TensorDesc()
tensor_desc.data_type = Program.Dtype(weight.dtype) tensor_desc.data_type = Program.Dtype(weight.dtype)
...@@ -448,12 +452,11 @@ class Writer(object): ...@@ -448,12 +452,11 @@ class Writer(object):
@staticmethod @staticmethod
def write_weights(weights, save_dir): def write_weights(weights, save_dir):
""" """
write multiple weights in each fluid desc write multiple weights in each fluid desc
""" """
for name, weight in weights.items(): for name, weight in weights.items():
if not isinstance(weights, dict): assert isinstance(weights, dict), 'dict type weights required'
raise TypeError('dict type weights required')
var_name = make_var_name(name) var_name = make_var_name(name)
filename = os.path.join(save_dir, var_name) filename = os.path.join(save_dir, var_name)
...@@ -463,8 +466,8 @@ class Writer(object): ...@@ -463,8 +466,8 @@ class Writer(object):
@staticmethod @staticmethod
def write_code_file(filename, header_code, *body_codes): def write_code_file(filename, header_code, *body_codes):
""" """
write Python code to file write Python code to file
""" """
codes = [] codes = []
Writer.add_codes(codes, header_code, 0) Writer.add_codes(codes, header_code, 0)
...@@ -481,8 +484,8 @@ class Writer(object): ...@@ -481,8 +484,8 @@ class Writer(object):
@staticmethod @staticmethod
def write_desc_file(filename, op_descs, var_descs): def write_desc_file(filename, op_descs, var_descs):
""" """
write desc program to file write desc program to file
""" """
prog_desc = framework_pb2.ProgramDesc() prog_desc = framework_pb2.ProgramDesc()
block_desc = prog_desc.blocks.add() block_desc = prog_desc.blocks.add()
......
...@@ -54,8 +54,8 @@ zip_safe = True ...@@ -54,8 +54,8 @@ zip_safe = True
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
onnx2fluid = onnx2fluid.__main__ onnx2fluid = onnx2fluid.__main__
onnx2fluid_convert = onnx2fluid.conversion onnx2fluid_convert = onnx2fluid.conversion:main
onnx2fluid_validate = onnx2fluid.validation onnx2fluid_validate = onnx2fluid.validation:main
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下 # 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 仅支持文件,不支持目录,但可以使用通配 # 仅支持文件,不支持目录,但可以使用通配
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册