Commit 492e9661 authored by M Macrobull

add new ONNX polish_model

Parent 826481c4
@@ -33,7 +33,7 @@ def main(**kwargs):
     from .conversion import convert
 
     logger = logging.getLogger('onnx2fluid')
-    debug = kwargs.get('debug', False)
+    # debug = kwargs.get('debug', False)
 
     # prepare arguments
     filename = kwargs.pop('model')[0]
@@ -65,8 +65,7 @@ def main(**kwargs):
         from .validation import validate
 
         save_inference_model = infer_inputs is not None
-        inference_input_names = infer_inputs.split(
-            ',') if infer_inputs else None
+        inference_input_names = infer_inputs and infer_inputs.split(',')
 
         logger.info('starting validation on desc ...')
         passed &= validate(shutil.os.path.join(save_dir, '__model__'),
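The collapsed one-liner in the hunk above leans on short-circuit `and`: a falsy `infer_inputs` is passed through unchanged instead of being split. A standalone sketch of the resulting values (not part of the commit):

    infer_inputs = 'image,prob'
    print(infer_inputs and infer_inputs.split(','))  # ['image', 'prob']

    infer_inputs = None
    print(infer_inputs and infer_inputs.split(','))  # None, as before

    infer_inputs = ''
    print(infer_inputs and infer_inputs.split(','))  # '' now, was None -- still falsy,
                                                     # but an `is None` test would differ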
@@ -85,7 +84,7 @@ def main(**kwargs):
                            **kwargs)
         if not passed:
-            logger.error('validation failed, exit')
+            logger.fatal('validation failed, exit')
             return
 
     # create zip file
...
@@ -34,15 +34,12 @@ def convert(onnx_model_filename,
     from onnx.checker import ValidationError
     from onnx.checker import check_model
-    from onnx.utils import polish_model
     from onnx.version_converter import convert_version
 
     from .onnx_utils import DEFAULT_OP_DOMAIN
     from .onnx_utils import graph_ops, graph_weights
     from .onnx_utils import inferred_model_value_info
-    from .onnx_utils import optimize_model_skip_op_for_inference
-    from .onnx_utils import optimize_model_strip_initializer
-    from .onnx_utils import optimize_model_cast, optimize_model_slice
+    from .onnx_utils import polish_model
     from .writer import Program, Writer
     from .writer import make_var_name
@@ -56,14 +53,12 @@ def convert(onnx_model_filename,
         logger.info('checking model ...')
         check_model(onnx_model)
 
         if onnx_opset_version is None:  # WORKAROUND: RuntimeError: No Adapter For OP
-            logger.debug('assumed opset version: %d',
-                         DEFAULT_ONNX_OPSET_VERSION)
             logger.warning(
                 'opset conversion skipped for onnx_opset_pedantic is OFF')
+            logger.info('assumed opset version: %d', DEFAULT_ONNX_OPSET_VERSION)
         else:
-            logger.debug('using opset version: %d', onnx_opset_version)
+            logger.info('using opset version: %d', onnx_opset_version)
             onnx_model = convert_version(onnx_model, onnx_opset_version)
+        onnx_model = polish_model(onnx_model)
     except ValidationError as e:
         if onnx_opset_pedantic:
             raise e
@@ -75,10 +70,7 @@ def convert(onnx_model_filename,
     # onnx model optimization
     logger.info('model has %d ops', len(onnx_model.graph.node))
     logger.info('optimizing model ...')
-    onnx_model = optimize_model_skip_op_for_inference(onnx_model)
-    onnx_model = optimize_model_strip_initializer(onnx_model)
-    onnx_model = optimize_model_cast(onnx_model)
-    onnx_model = optimize_model_slice(onnx_model)
+    onnx_model = polish_model(onnx_model)
 
     # prepare filesystem
     shutil.rmtree(save_dir, ignore_errors=True)
@@ -87,9 +79,8 @@ def convert(onnx_model_filename,
     # DEBUG:
     if debug:
-        model = onnx.shape_inference.infer_shapes(onnx_model)
         debug_model_filename, _ = shutil.os.path.splitext(onnx_model_filename)
-        onnx.save(model, debug_model_filename + '.optimized_and_inffered.onnx')
+        onnx.save(onnx_model, debug_model_filename + '.polished.onnx')
 
     # I/O instances
     onnx_graph = onnx_model.graph
@@ -141,11 +132,11 @@ def convert(onnx_model_filename,
     logger.info('%d ops in, %d ops out', len(onnx_graph.node),
                 len(fluid_program.op_descs))
 
-    # type-shape inference
+    # type-shape info copy
     for name, value_info in graph_value_infos.items():
         var_name = make_var_name(name)
         fluid_program.VarTypeShapeInfo(var_name, value_info,
-                                       remove_batch=False)  # shape-infer only
+                                       remove_batch=False)  #
     bad_var_names = []
     for var_name, var_desc in fluid_program.var_descs.items():
         if not var_desc.type.lod_tensor.HasField('tensor'):
@@ -155,8 +146,8 @@ def convert(onnx_model_filename,
                        ', '.join(bad_var_names[:5]))
         logger.warning('this causes little problem for PaddlePaddle, '
                        'but Paddle Mobile may not infer correctly')
-        logger.warning('please consider running onnx2fluid.validation with -i '
-                       'to invoke PaddlePaddle type-shape inference')
+        logger.warning('please consider running validation with -i '
+                       'to invoke type-shape inference in PaddlePaddle')
 
     # weight writer
     for name, weight in graph_weights(onnx_graph):
...
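The conversion path now funnels everything through the new `polish_model` helper, defined in onnx_utils.py later in this commit. A quick way to observe its effect on a model, assuming the `onnx2fluid.onnx_utils` import path and reusing the `/tmp/export.onnx` sample path from this commit's `__main__` blocks:

    import onnx
    from onnx2fluid.onnx_utils import polish_model  # assumed import path

    model = onnx.load('/tmp/export.onnx')
    print('value_info entries before:', len(model.graph.value_info))
    model = polish_model(model)  # check + strip doc strings + optimize + infer_shapes
    print('value_info entries after: ', len(model.graph.value_info))
    onnx.save(model, '/tmp/export.polished.onnx')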
@@ -11,9 +11,11 @@ from __future__ import division
 import logging
 import numpy as np
 import onnx
+import onnx.optimizer as optimizer
 
 from collections import OrderedDict as Dict  # as default dict
-from onnx.helper import get_attribute_value, make_attribute
+from onnx.checker import check_model
+from onnx.helper import get_attribute_value, make_attribute, strip_doc_string
 from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 from onnx.numpy_helper import to_array
 from onnx.shape_inference import infer_shapes
@@ -23,14 +25,16 @@ logger = logging.getLogger(__name__)
 __all__ = [
     'print_pb_structure',
     'build_value_refs',
-    'tensor_dtype',
-    'tensor_shape',
     'node_attrs',
     'node_topo',
     'node_iter',
+    'tensor_dtype',
+    'tensor_shape',
     'graph_ops',
     'graph_weights',
     'inferred_model_value_info',
+    'polish_model',
+    'polish_and_save',
     'optimize_model_skip_op_for_inference',
     'optimize_model_strip_initializer',
     'optimize_model_cast',
@@ -110,7 +114,7 @@ def tensor_shape(tensor):
     get ONNX tensor shape
     """
 
-    return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]
+    return tuple([dim.dim_value for dim in tensor.type.tensor_type.shape.dim])
 
 
 def node_attrs(node):
@@ -195,10 +199,7 @@ def node_iter(nodes, indices=None):
     generator for ONNX node graph with given indices
     """
 
-    if indices is None:
-        indices = range(len(nodes))
-
-    for index in indices:
+    for index in indices or range(len(nodes)):
         node = nodes[index]
         name = node.name
         domain = node.domain
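The tightened loop header uses `or` instead of the explicit `is None` test, which also changes the edge case of an explicitly empty `indices`; a standalone sketch:

    nodes = ['n0', 'n1', 'n2']

    for index in [] or range(len(nodes)):   # an empty list is falsy: all nodes are visited
        print(index)                        # 0, 1, 2

    # the old code only substituted range(len(nodes)) when indices was None,
    # so an empty list used to yield nothing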
@@ -306,6 +307,48 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
     return processed
 
 
+def polish_model(model, extras=True):
+    """
+    polish_model enhanced for inference
+    """
+
+    check_model(model)
+    strip_doc_string(model)
+    passes = optimizer.get_available_passes()
+    passes = list(filter(lambda name: not name.startswith('split_'), passes))  #
+    logger.debug('builtin optimizations to perform in ONNX:\n\t%s', passes)
+    model = optimizer.optimize(model, passes=passes)
+    if extras:
+        for optimize in (
+                optimize_model_skip_op_for_inference,
+                optimize_model_strip_initializer,
+                optimize_model_cast,
+                optimize_model_slice,
+        ):
+            model = optimize(model)
+    model = infer_shapes(model)
+    check_model(model)
+    return model
+
+
+def polish_and_save(model_filename,
+                    suffix='.polished',
+                    save_filename=None,
+                    *args,
+                    **kwargs):
+    """
+    run polish_model and save
+    """
+
+    model = onnx.load(model_filename)
+    model = polish_model(model, *args, **kwargs)
+    save_filename = save_filename or model_filename.replace(
+        '.onnx', suffix + '.onnx')
+    onnx.save(model, save_filename)
+    logger.info('polished model saved to: %s', save_filename)
+    return save_filename
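A hypothetical use of the two new helpers (the import path and file names are assumptions, not from the commit):

    from onnx2fluid.onnx_utils import polish_and_save  # assumed import path

    # default suffix: writes '/tmp/model.polished.onnx' next to the input
    polished_path = polish_and_save('/tmp/model.onnx')

    # extra keyword arguments are forwarded to polish_model;
    # extras=False keeps only the builtin ONNX optimizer passes
    slim_path = polish_and_save('/tmp/model.onnx', suffix='.slim', extras=False)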
 def optimize_model_skip_op_for_inference(model, op_list=None):
     """
     skip ops can be bypassed for inference
@@ -326,7 +369,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
         if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
             continue
 
         op_type = node.op_type
-        if not (op_type in op_list):
+        if op_type not in op_list:
             continue
 
         if op_type in ('Dropout', ):
@@ -590,22 +633,16 @@ if __name__ == '__main__':
         level=logging.DEBUG,
     )
 
-    from onnx.checker import check_model
-    from onnx.utils import polish_model
     from onnx.version_converter import convert_version
 
-    model = onnx.load('../examples/t1.onnx')
+    model = onnx.load('/tmp/export.onnx')
     print_pb_structure(model, loop_iterative=False)
 
     check_model(model)
     model = convert_version(model, 9)
-    model = optimize_model_skip_op_for_inference(model)
-    model = optimize_model_strip_initializer(model)
-    model = optimize_model_cast(model)
-    model = optimize_model_slice(model)
     model = polish_model(model)
-    onnx.save(model, '/tmp/optimized.onnx')
+    onnx.save(model, '/tmp/export.optimized.onnx')
 
     graph = model.graph
     value_info = inferred_model_value_info(model)
@@ -617,23 +654,23 @@ if __name__ == '__main__':
     logger.info('ops:')
     for name, domain, op_type, _, _, attrs in graph_ops(graph, topo='forward'):
-        logger.info('%s %s::%s: %s', name, domain, op_type, attrs)
+        logger.info('- \t%s %s::%s: %s', name, domain, op_type, attrs)
 
     logger.info('weights:')
     for name, array in graph_weights(graph):
         weights.append(name)
-        logger.info('%s: %s', name, array.shape)
+        logger.info('- \t%s: %s', name, array.shape)
 
     logger.info('inputs:')
     external_inputs = []
     for name in inputs:
         if name not in weights:
             external_inputs.append(name)
-            logger.info('%s: %s', name, value_info[name]['shape'])
+            logger.info('- \t%s: %s', name, value_info[name]['shape'])
 
     logger.info('outputs:')
     external_outputs = []
     for name in outputs:
         if name not in weights:
             external_outputs.append(name)
-            logger.info('%s: %s', name, value_info[name]['shape'])
+            logger.info('- \t%s: %s', name, value_info[name]['shape'])
...
@@ -203,8 +203,7 @@ def _check_embeddable(value_infos, *val_names):
     keyword = 'get_weight'
     for val_name in val_names:
         if keyword not in value_infos[val_name]:
-            _logger.warning('parameter %s not embeddable for some ops',
-                            val_name)
+            _logger.warning('parameter %s not embeddable', val_name)
             return False
 
     return True
@@ -240,9 +239,9 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
     fluid_attrs = default_attrs.copy()
     fluid_attrs.update(mapped_attrs)  # as new attrs
 
-    val_inps = inputs if input_perm is None else map(lambda i: inputs[i],
-                                                     input_perm)
-    val_outs = outputs if output_perm is None else map(lambda i: outputs[i],
-                                                       output_perm)
+    val_inps = inputs if input_perm is None else map(inputs.__getitem__,
+                                                     input_perm)
+    val_outs = outputs if output_perm is None else map(outputs.__getitem__,
+                                                       output_perm)
     var_inps = [_make_var_name(val) for val in val_inps]
     var_outs = [_make_var_name(val) for val in val_outs]
@@ -578,7 +577,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
             1] == 1, 'only scale on (NC)HW supported'
         assert scales[2] == scales[
             3], 'only aspect-ratio-invariant scale supported'
-        scale = None if scales is None else scales[2]
+        scale = scales and scales[2]
     # try input shape
     if scale is None:
         assert out_shape_, 'neither scales nor output shape is available'
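The shortened `scales and scales[2]` keeps the `None` pass-through by short-circuiting; a standalone sketch:

    scales = [1.0, 1.0, 2.0, 2.0]
    scale = scales and scales[2]   # 2.0
    scales = None
    scale = scales and scales[2]   # None -> the `if scale is None` branch still fires
    # note: an empty sequence would yield [] (falsy but not None); the asserts
    # on scales[0..3] above rule that case out before this line is reached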
@@ -717,6 +716,10 @@ def BatchNormalization(prog,
     if embed_params:
         embed_params = _check_embeddable(value_infos, val_scale, val_b,
                                          val_mean, val_var)
+        if not embed_params and name:
+            _logger.warning('for op %s(%s -> BatchNormalization -> %s)', name,
+                            inputs, outputs)
+            _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
         var_scale = name + '.w_0'
@@ -875,7 +878,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     if shape is None:
         shape = list(value.shape)
         _logger.warning(
-            'in (Constant -> %s): '
+            'in op (Constant -> %s): '
             'attribute "shape" of %s not inferred, '
             'using value as 1-D tensor may lead to fails', outputs, val_output)
@@ -986,6 +989,10 @@ def Conv(prog,
     if embed_params:
         embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
                         or _check_embeddable(value_infos, val_b))
+        if not embed_params and name:
+            _logger.warning('for op %s(%s -> Conv -> %s)', name, inputs,
+                            outputs)
+            _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
         var_w = name + '.w_0'
@@ -1099,6 +1106,10 @@ def ConvTranspose(prog,
     if embed_params:
         embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
                         or _check_embeddable(value_infos, val_b))
+        if not embed_params and name:
+            _logger.warning('for op %s(%s -> ConvTranspose -> %s)', name,
+                            inputs, outputs)
+            _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
         var_w = name + '.w_0'
@@ -1167,23 +1178,6 @@ def ConvTranspose(prog,
     prog.VarDesc(var_y)
 
 
-# should not appear
-#def Dropout(
-#        prog, inputs, outputs, value_infos,
-#        *args, **kwargs):
-#    """
-#    onnx::Dropout-7:9
-#    """
-#
-#    val_data, = inputs
-#    val_output, = outputs[:1]
-#
-#    _assign(prog,
-#            dict([(val_output, val_data)]),
-#            value_infos,
-#            )
-
-
 def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     """
     onnx::Gemm-9:
@@ -1236,7 +1230,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
         if vm_dtype is None:
             vm_dtype = _np.dtype('float32')
             _logger.warning(
-                'in %s(%s -> Gemm -> %s): '
+                'in op %s(%s -> Gemm -> %s): '
                 'attribute "beta" seems to be an interger, '
                 'however dtype can not be inferred, '
                 'still use float32', name, inputs, outputs)
@@ -1425,6 +1419,10 @@ def PRelu(prog,
     name_attr = ', name={}'.format(repr(name)) if name else ''
     if embed_params:
         embed_params = _check_embeddable(value_infos, val_slope)
+        if not embed_params and name:
+            _logger.warning('for op %s(%s -> PRelu -> %s)', name, inputs,
+                            outputs)
+            _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
         var_slope = name + '.w_0'
@@ -1487,7 +1485,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     if shape is None:
         shape = [1, -1]  # who knows
         _logger.warning(
-            'in %s(%s -> Reshape -> %s): '
+            'in op %s(%s -> Reshape -> %s): '
             'input "shape" not inferred, use [1, -1] as dummy value, '
            'the behavior of Paddle fluid maybe undefined', name, inputs,
            outputs)
...
@@ -9,22 +9,50 @@ Created on Fri Mar 22 11:22:46 2019
 import numpy as np
 import torch
 
-from collections import OrderedDict as Dict
+from collections import OrderedDict
+from typing import (
+    TypeVar,
+    Any,
+    Generic,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Text,
+    Tuple,
+    Union,
+)
+
+__all__ = [
+    'export_data',
+    'export_onnx_with_validation',
+]
+
+my_dict = OrderedDict
+KT = TypeVar('KT')
+VT = TypeVar('VT')
+
+
+class MyDict(my_dict, Generic[KT, VT]):
+    pass
 
 
-def ensure_list(obj):
+def ensure_list(obj: Union[object, Sequence[object]]) -> List[object]:
     if isinstance(obj, (list, tuple, set)):
         return list(obj)
     return [obj]
 
 
-def ensure_tuple(obj):
+def ensure_tuple(obj: Union[object, Sequence[object]]) -> Tuple[object, ...]:
     if isinstance(obj, (tuple, list, set)):
         return tuple(obj)
     return (obj, )
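The added annotations leave runtime behavior unchanged; a quick sanity sketch, given the definitions above:

    assert ensure_list((1, 2)) == [1, 2]
    assert ensure_list('ab') == ['ab']        # scalars and strings get wrapped, not exploded
    assert ensure_tuple([1, 2]) == (1, 2)
    assert ensure_tuple(3) == (3, )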
-def flatten_list(obj, out=None):
+def flatten_list(obj: List[Union[object, List[object]]],
+                 out: Optional[List[object]] = None) -> List[object]:
     assert isinstance(obj, list), 'list type required'
 
     if out is None:
@@ -37,21 +65,21 @@ def flatten_list(obj, out=None):
     return out
 
 
-def export_data(state_dict, prefix=''):
+def export_data(state_dict: Mapping[Text, Any], prefix: Text = '') -> None:
     """
     export binary data with meta text for raw C++ inference engines
     """
 
-    def str_(obj):
+    def str_(obj: object) -> Text:
         if isinstance(obj, (tuple, list, set)):
             return str(obj)[1:-1].replace(' ', '')
         return str(obj)
 
     prefix_ = prefix + ('_' if prefix else '')
-    fp = open('{}.txt'.format(prefix if prefix else 'meta'), 'w')
+    fp = open('{}.txt'.format(prefix or 'meta'), 'w')
     for key, value in state_dict.items():
         data = None
-        if torch and torch.is_tensor(value):
+        if torch.is_tensor(value):
             data = value.data.cpu().numpy()
         elif isinstance(value, np.ndarray):
             data = value
@@ -64,30 +92,33 @@ def export_data(state_dict, prefix=''):
     fp.close()
 
 
-def export_onnx_with_validation(model,
-                                inputs,
-                                export_basepath,
-                                input_names=None,
-                                output_names=None,
-                                use_npz=True,
-                                *args,
-                                **kwargs):
+def export_onnx_with_validation(
+        model: torch.nn.Module,
+        inputs: Sequence[Union[torch.Tensor, Sequence[object]]],
+        export_basepath: Text,
+        input_names: Optional[List[Text]] = None,
+        output_names: Optional[List[Text]] = None,
+        use_npz: bool = True,
+        *args,
+        **kwargs) -> Sequence[Union[torch.Tensor, Sequence[object]]]:
     """
     export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
     """
 
     is_tuple_or_list = lambda x: isinstance(x, (tuple, list))
 
-    def tensors_to_arrays(tensors):
+    def tensors_to_arrays(tensors: Union[torch.Tensor, Iterable[
+            Union[torch.Tensor, Iterable[Any]]]], ) -> List[np.ndarray]:
         if torch.is_tensor(tensors):
             return tensors.data.cpu().numpy()
-        arrays = []
-        for tensor in tensors:
-            arrays.append(tensors_to_arrays(tensor))
-        return arrays
-
-    def zip_dict(keys, values):
-        ret = Dict()
+        return list(map(tensors_to_arrays, tensors))
+
+    def zip_dict(
+            keys: Union[Iterable[Any], None],
+            values: Sequence[Union[Any, Sequence[Any]]],
+    ) -> MyDict[Text, Union[object, MyDict[Text, object]]]:
+        keys = keys or range(len(values))
+        ret = my_dict()
         for idx, (key, value) in enumerate(zip(keys, values)):
             is_key_list = is_tuple_or_list(key)
             is_value_list = is_tuple_or_list(value)
@@ -102,19 +133,48 @@ def export_onnx_with_validation(model,
     outputs = torch.onnx.export(model,
                                 torch_inputs,
                                 export_basepath + '.onnx',
-                                input_names=flatten_list(input_names),
-                                output_names=flatten_list(output_names),
+                                input_names=(None if input_names is None else
+                                             flatten_list(input_names)),
+                                output_names=(None if output_names is None else
+                                              flatten_list(output_names)),
                                 *args,
                                 **kwargs)
     if outputs is None:  # WORKAROUND: for torch.onnx
-        outputs = model(*inputs)
+        training = kwargs.get('training', False)
+        with torch.onnx.set_training(model, training):
+            outputs = model(*inputs)
     torch_outputs = ensure_tuple(outputs)
 
     inputs = zip_dict(input_names, tensors_to_arrays(torch_inputs))
     outputs = zip_dict(output_names, tensors_to_arrays(torch_outputs))
     if use_npz:
-        np.savez(export_basepath + '.npz', inputs=inputs, outputs=outputs)
+        np.savez(
+            export_basepath + '.npz',
+            inputs=inputs,
+            outputs=outputs,
+        )
     else:
-        np.save(export_basepath + '.npy',
-                np.array(Dict(inputs=inputs, outputs=outputs)))
+        np.save(export_basepath + '.npy',
+                np.asarray(my_dict(inputs=inputs, outputs=outputs)),
+                allow_pickle=True)
 
     return torch_outputs
+
+
+if __name__ == '__main__':
+    from torchvision.models import resnet18 as net
+
+    model = net()
+    xb = torch.rand((1, 3, 224, 224))
+    export_onnx_with_validation(
+        model,
+        (xb, ),
+        '/tmp/export',
+        input_names=[
+            'image',
+        ],
+        output_names=[
+            'prob',
+        ],
+        use_npz=True,
+    )
...
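The golden data written by the new `__main__` block can be read back exactly the way validation.py below does it; `allow_pickle=True` is required on NumPy >= 1.16.3 because the stored dicts are pickled objects:

    import numpy as np

    test_data = np.load('/tmp/export.npz', encoding='bytes', allow_pickle=True)
    input_data = test_data['inputs'].tolist()     # {'image': ndarray of shape (1, 3, 224, 224)}
    output_data = test_data['outputs'].tolist()   # {'prob': ndarray}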
@@ -8,6 +8,13 @@ Created on Fri Mar 22 12:17:19 2019
 import importlib, logging, os, sys
 
+logger = logging.getLogger(__name__)
+
+__all__ = [
+    'fluid_prog_shape_infer',
+    'validate',
+]
+
 
 def flatten_dict(obj, out=None):
     assert isinstance(obj, dict), 'dict type required'
@@ -29,6 +36,42 @@ def ensure_list(obj):
     return [obj]
 
 
+def fluid_prog_shape_infer(prog):
+    """
+    additional type-shape inference for fluid program
+    """
+
+    import paddle.fluid as fluid
+
+    assert isinstance(prog, fluid.framework.Program)
+
+    logger.info('performing type-shape inference ...')
+    for block in prog.blocks:
+        block_desc = block.desc
+
+        for idx_op in range(block_desc.op_size()):
+            op_desc = block_desc.op(idx_op)
+            if op_desc.type() in ('feed', 'fetch'):
+                continue
+
+            op_desc.infer_var_type(block_desc)
+            op_desc.infer_shape(block_desc)
+
+        for var_name, var in block.vars.items():
+            var_desc = var.desc
+            if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR:
+                continue
+
+            # WORKAROUND: dirty way to give dtype to partial-infered vars
+            # which could not be cleared!
+            try:
+                var.to_string(True)
+            except ValueError:
+                var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32)
+                logger.debug('dtype of var %s not inferred, float32 assumed',
+                             var_name)
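A hypothetical standalone use of the newly extracted helper (the import path and model directory are placeholders, not from the commit):

    import paddle.fluid as fluid
    from onnx2fluid.validation import fluid_prog_shape_infer  # assumed import path

    exe = fluid.Executor(fluid.CPUPlace())
    prog, feed_names, fetch_vars = fluid.io.load_inference_model(
        '/tmp/export.paddle', exe)   # placeholder model directory
    fluid_prog_shape_infer(prog)     # fills in missing var types/shapes in place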
 def validate(fluid_model_filename,
              golden_data_filename='',
              atol=1e-3,
@@ -53,12 +96,12 @@ def validate(fluid_model_filename,
     # load model
     fluid_model_dir, basename = os.path.split(fluid_model_filename)
     if basename == '__model__':  # is desc program
-        logger.debug('using desc file %s', basename)
+        logger.info('using desc file %s', basename)
         prog, _, var_outs = fluid.io.load_inference_model(fluid_model_dir, exe)
         out_names = var_outs  # HINT: pass var if fetch ops already created
         logger.info('model load passed')
     elif basename.endswith('.py'):  # is Python code
-        logger.debug('using code file %s', basename)
+        logger.info('using code file %s', basename)
         module_name, _ = os.path.splitext(basename)
         sys_path = sys.path.copy()
         sys.path.append(fluid_model_dir)
@@ -91,18 +134,28 @@ def validate(fluid_model_filename,
     if golden_data_filename:
         logger.info('using golden data %s', golden_data_filename)
         if golden_data_filename.endswith('.npz'):
-            test_data = np.load(golden_data_filename, encoding='bytes')
+            test_data = np.load(
+                golden_data_filename,
+                encoding='bytes',
+                allow_pickle=True,
+            )
             input_data = test_data['inputs'].tolist()
             output_data = test_data['outputs'].tolist()
         else:
-            test_data = np.load(golden_data_filename, encoding='bytes').tolist()
+            test_data = np.load(
+                golden_data_filename,
+                encoding='bytes',
+                allow_pickle=True,
+            ).tolist()
             input_data = test_data['inputs']
             output_data = test_data['outputs']
         input_data = flatten_dict(input_data)
         output_data = flatten_dict(output_data)
         input_names = input_data.keys()
-        logger.info('found %d I/O golden data, starting test ...',
-                    len(input_data) + len(output_data))
+        output_names = output_data.keys()
+        logger.info('with %d inputs and %d outputs', len(input_data),
+                    len(output_data))
     else:
         assert inference_input_names, 'input names required for type-shape inference'
@@ -111,25 +164,7 @@ def validate(fluid_model_filename,
     # type-shape inference and re-save
     if save_inference_model:
-        for block in prog.blocks:
-            block_desc = block.desc
-
-            for idx_op in range(block_desc.op_size()):
-                op_desc = block_desc.op(idx_op)
-                if op_desc.type() in ('feed', 'fetch'):
-                    continue
-
-                op_desc.infer_var_type(block_desc)
-                op_desc.infer_shape(block_desc)
-
-            for var_name, var in block.vars.items():
-                var_desc = var.desc
-                if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR:
-                    continue
-
-                # WORKAROUND: dirty way to give dtype to partial-infered vars
-                # which could not be cleared!
-                try:
-                    var.to_string(True)
-                except ValueError:
-                    var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32)
+        fluid_prog_shape_infer(prog)
         fluid.io.save_inference_model(fluid_model_dir,
                                       input_names,
                                       var_outs,
@@ -151,7 +186,7 @@ def validate(fluid_model_filename,
     # validate
     passed = True
     for (name, truth), output in zip(output_data.items(), outputs):
-        logger.info('testing output {} ...'.format(name))
+        logger.info('testing on output {} ...'.format(name))
         try:
             np.testing.assert_allclose(output,
                                        truth,
@@ -162,10 +197,7 @@ def validate(fluid_model_filename,
         except AssertionError as e:
             passed = False
             logger.error('failed: %s\n', e)
-    if passed:
-        logger.info('accuracy passed')
-    else:
-        logger.info('accuracy not passed')
+    logger.info('accuracy %spassed', '' if passed else 'not ')
 
     return passed
...
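Driving the same validation from Python instead of the CLI, using only the keyword arguments visible in the signature above (paths are placeholders):

    from onnx2fluid.validation import validate  # assumed import path

    passed = validate('/tmp/export.paddle/__model__',
                      golden_data_filename='/tmp/export.npz',
                      atol=1e-3)
    print('validation passed:', passed)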