提交 ba40d265 编写于 作者: M Macrobull

optimize symbolic

上级 52a502df
......@@ -69,6 +69,22 @@ parser.add_argument(
dest='pedantic',
help='process non-standard ONNX ops, this may lead to fails',
)
parser.add_argument(
'--skip-version-conversion',
'-y',
action='store_true',
default=False,
help='skip ONNX op version conversion, workaround for RumtimeErrors',
)
parser.add_argument(
'--archive',
'-z',
nargs='?',
type=str,
default=None,
const='',
help='compress outputs to ZIP file if conversion successed',
)
parser.add_argument(
'--precision',
'-p',
......
......@@ -16,10 +16,10 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
# import logging, shutil, zipfile
import logging
import shutil
import zipfile
import logging, shutil, zipfile
#import logging
#import shutil
#import zipfile
__all__ = [
'main',
......@@ -49,12 +49,14 @@ def main(**kwargs):
basepath, _ = shutil.os.path.splitext(filename)
save_dir = kwargs.get('output_dir', '')
# model.onnx -> model/
save_dir = shutil.os.path.dirname(save_dir) if save_dir else basepath
save_dir = (save_dir.rstrip('/') if save_dir else basepath) + '/'
model_basename = DEFAULT_MODEL_MODULE + '.py'
model_func_name = DEFAULT_MODEL_FUNC
embed_params = kwargs.get('embed_params', False)
onnx_opset_version = DEFAULT_ONNX_OPSET_VERSION
onnx_opset_pedantic = kwargs.get('pedantic', True)
onnx_skip_version_conversion = kwargs.get('skip_version_conversion', False)
archive = kwargs.get('archive', None)
# convert
convert(
......@@ -65,6 +67,7 @@ def main(**kwargs):
embed_params=embed_params,
onnx_opset_version=onnx_opset_version,
onnx_opset_pedantic=onnx_opset_pedantic,
onnx_skip_version_conversion=onnx_skip_version_conversion,
debug=debug)
# validate
......@@ -104,13 +107,21 @@ def main(**kwargs):
return
# create zip file
fn_zip = save_dir.rstrip('/') + '.zip'
logger.info('compressing file to %s ...', fn_zip)
fz = zipfile.ZipFile(fn_zip, 'w', compression=zipfile.ZIP_LZMA)
for fn in shutil.os.listdir(save_dir):
fz.write(shutil.os.path.join(save_dir, fn), arcname=fn)
fz.close()
logger.info('compressing done')
if archive is not None:
if archive == '':
archive = save_dir.rstrip('/') + '.zip'
logger.info('compressing file to %s ...', archive)
shutil.sys.stderr.write('\n')
shutil.sys.stderr.flush()
file_list = shutil.os.listdir(save_dir)
fz = zipfile.ZipFile(archive, 'w', compression=zipfile.ZIP_LZMA)
for idx, fn in enumerate(file_list):
shutil.sys.stderr.write('\033[F\033[2K')
logger.info('file {}/{}: {}'.format(idx + 1, len(file_list), fn))
shutil.sys.stderr.flush()
fz.write(shutil.os.path.join(save_dir, fn), arcname=fn)
fz.close()
logger.info('compressing done')
if __name__ == '__main__':
......@@ -132,5 +143,6 @@ if __name__ == '__main__':
output_dir='/tmp/export/',
embed_params=True,
pedantic=False,
skip_version_conversion=False,
test_data='../examples/inception_v2/test_data_set_2.npz',
debug=True)
......@@ -8,9 +8,9 @@ Created on Mon Feb 25 09:50:35 2019
from __future__ import division
# import logging, shutil
import logging
import shutil
import logging, shutil
#import logging
#import shutil
__all__ = [
'convert',
......@@ -24,6 +24,7 @@ def convert(onnx_model_filename,
embed_params=False,
onnx_opset_version=9,
onnx_opset_pedantic=True,
onnx_skip_version_conversion=False,
debug=False):
"""
convert an ONNX model to Paddle fluid Python code and desc pb
......@@ -60,12 +61,13 @@ def convert(onnx_model_filename,
try:
logger.info('checking model ...')
check_model(onnx_model)
logger.debug('using opset version: %d', onnx_opset_version)
if onnx_opset_pedantic: # WORKAROUND: RuntimeError: No Adapter For OP
onnx_model = convert_version(onnx_model, onnx_opset_version)
else: # TODO: add new argument for this option
if onnx_skip_version_conversion: # WORKAROUND: RuntimeError: No Adapter For OP
logger.debug('assumed opset version: %d', onnx_opset_version)
logger.warning(
'opset conversion skipped for onnx_opset_pedantic is OFF')
else:
logger.debug('using opset version: %d', onnx_opset_version)
onnx_model = convert_version(onnx_model, onnx_opset_version)
onnx_model = polish_model(onnx_model)
except ValidationError as e:
if onnx_opset_pedantic:
......@@ -152,16 +154,15 @@ def convert(onnx_model_filename,
logger.info(
'weight %s is shared between ops, more disk space will be consumed',
name)
logger.debug(
'saving weight %s with size of %d, in %d bytes, as %s ...',
name, weight.size, weight.nbytes, var_names)
logger.debug('saving weight %s(%s[%d], %dB) as %s ...', name,
weight.dtype, weight.size, weight.nbytes, var_names)
for var_name in var_names: # multiple references
fluid_writer.write_weight(
weight, shutil.os.path.join(save_dir, var_name))
else:
logger.debug(
'saving weight %s with size of %d, in %d bytes, to %s ...',
name, weight.size, weight.nbytes, make_var_name(name))
logger.debug('saving weight %s(%s[%d], %dB) to %s ...', name,
weight.dtype, weight.size, weight.nbytes,
make_var_name(name))
fluid_writer.write_weight(
weight, shutil.os.path.join(save_dir, make_var_name(name)))
fluid_writer.emit_param(fluid_program, name, value_info)
......@@ -262,6 +263,13 @@ if __name__ == '__main__':
dest='pedantic',
help='process non-standard ONNX ops, this may lead to fails',
)
parser.add_argument(
'--skip-version-conversion',
'-y',
action='store_true',
default=False,
help='skip ONNX op version conversion, workaround for RumtimeErrors',
)
args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
......@@ -273,10 +281,12 @@ if __name__ == '__main__':
save_dir = args.output_dir
embed_params = args.embed_params
pedantic = args.pedantic
skip_version_conversion = args.skip_version_conversion
convert(
model_filename,
save_dir,
embed_params=embed_params,
onnx_opset_pedantic=pedantic,
onnx_skip_version_conversion=skip_version_conversion,
debug=debug)
......@@ -26,6 +26,7 @@ __all__ = [
'node_attrs',
'node_topo',
'node_iter',
'tensor_dtype',
'tensor_shape',
'graph_ops',
'graph_weights',
......@@ -92,13 +93,12 @@ def get_attribute_value2(attr):
return value
def node_attrs(node):
def tensor_dtype(tensor):
"""
convert ONNX node attributes to dict
get ONNX tensor in np.dtype
"""
return {attr.name: get_attribute_value2(attr)
for attr in node.attribute} # dict
return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]
def tensor_shape(tensor):
......@@ -109,6 +109,15 @@ def tensor_shape(tensor):
return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]
def node_attrs(node):
"""
convert ONNX node attributes to dict
"""
return {attr.name: get_attribute_value2(attr)
for attr in node.attribute} # dict
def node_topo(nodes, topo='default'):
"""
build indices with given topology to an ONNX node graph
......@@ -237,21 +246,21 @@ def inferred_model_value_info(model):
value_info = Dict()
for item in graph.value_info:
value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
dtype=tensor_dtype(item),
shape=tensor_shape(item),
external=False,
)
for item in graph.input:
assert item.name not in value_info
value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
dtype=tensor_dtype(item),
shape=tensor_shape(item),
external=True,
)
for item in graph.output:
# assert item.name not in value_info, 'bypass-model not supported'
value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
dtype=tensor_dtype(item),
shape=tensor_shape(item),
external=True,
)
......@@ -373,9 +382,9 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
elif not keep_input_only and name in output_refs:
ret_initializers.add().CopyFrom(initializer)
else:
logger.debug('initializer %s(%s[%d]) stripped', name,
TENSOR_TYPE_TO_NP_TYPE[initializer.data_type],
len(initializer.raw_data))
dtype = TENSOR_TYPE_TO_NP_TYPE[initializer.data_type]
logger.debug('initializer %s(%s[%d]) stripped', name, dtype,
len(initializer.raw_data) // dtype.itemsize)
# strip inputs
ret.graph.ClearField('input')
......@@ -385,10 +394,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
if name in input_refs or name in out_names:
ret_inputs.add().CopyFrom(item)
else:
logger.debug(
'input %s(%s%s) stripped', name,
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
tensor_shape(item))
logger.debug('input %s(%s%s) stripped', name, tensor_dtype(item),
tensor_shape(item))
return ret
......
......@@ -19,7 +19,7 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
_logger = _logging.getLogger(__name__)
ONNX_INT_MAX = 2**63 - 1
FLUID_INT_MAX = 2**31 - 1
FLUID_INT_MAX = 2**31 - 1 #
DEFAULT_ONNX_OP_DOMAIN = ''
DEFAULT_FLUID_OP_NAMESCOPE = '/'
......@@ -186,13 +186,17 @@ def _shape_or_none(value_infos, val_name):
return list(value_info['shape'])
#def _maybe_const_value(value_infos, val_name):
# var_name = _make_var_name(val_name)
# if val_name not in value_infos:
# return var_name
# value_info = value_infos[val_name]
# assert value_info.get('remove_batch', False) == False, 'const value should not have batch dim'
# return value_info.get('const_value', var_name)
def _const_weight_or_none(value_infos, val_name):
if val_name not in value_infos:
return None
value_info = value_infos[val_name]
const_value = value_info.get('const_value', None)
if const_value:
return const_value
get_weight_func = value_info.get('get_weight', None)
if get_weight_func:
return get_weight_func()
return None
def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
......@@ -253,7 +257,7 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
num_vars = len(var_outs)
num_args = len(fluid_output_args)
if num_vars < num_args:
assert fill_name_field, 'name required to naming dummy output variable'
assert fill_name_field, 'name required to name dummy output variables'
for idx_out in range(num_vars, num_args):
var_out = _make_var_name(name + '.' +
fluid_output_args[idx_out].lower())
......@@ -294,9 +298,8 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
if pads[idx_dim] != pads[ndims + idx_dim]:
symmetric = False
break
if symmetric:
return pads[:ndims], None
return pads[:ndims], val_name
val_padded = val_name + '_padded'
prog.Op(
......@@ -315,13 +318,7 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
return [0] * ndims, val_padded
def _adaptive_pool(prog,
pool_type,
inputs,
outputs,
attrs,
value_infos,
name=''):
def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
# I/O
val_x, = inputs
val_y, = outputs[:1]
......@@ -335,10 +332,6 @@ def _adaptive_pool(prog,
# interpretation
pool_size = attrs['output_size'] # required
output_shape = _shape_or_none(value_infos, val_y)
if output_shape is not None:
assert pool_size == output_shape[
2:], 'pool_size unmatches shape of Y' # NC...
poolnd = len(pool_size)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
......@@ -445,11 +438,9 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
fluid_op = 'pool{}d'.format(poolnd)
strides = attrs.get('strides', [1] * poolnd) # optional
pads = attrs.get('pads', [0] * len(pool_size * 2)) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
if val_x_padded:
val_x = val_x_padded
ceil_mode = bool(attrs.get('ceil_mode', 0)) # optional
pads = attrs.get('pads', [0] * (poolnd * 2)) # optional
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else ''
......@@ -506,17 +497,17 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
spatial_scale = attrs['spatial_scale'] # required
pooled_height, pooled_width = attrs['pooled_shape'] # required
od_attrs = dict(
spatial_scale=spatial_scale,
pooled_height=pooled_height,
pooled_width=pooled_width,
spatial_scale=spatial_scale,
)
feature_attr = ''
is_max_pool = fluid_op == 'roi_pool'
if 'sampling_ratio' in attrs:
if 'sampling_ratio' in attrs: #
sampling_ratio = attrs['sampling_ratio']
od_attrs['sampling_ratio'] = sampling_ratio
feature_attr += ', sampling_ratio={}'.format(sampling_ratio)
if 'output_channels' in attrs:
if 'output_channels' in attrs: #
output_channels = attrs['output_channels']
od_attrs['output_channels'] = output_channels
feature_attr += ', output_channels={}'.format(output_channels)
......@@ -560,36 +551,20 @@ def _zeros_like(prog, val_ref, val_out, value_infos):
)
def AdaptiveAveragePool(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs):
def AdaptiveAveragePool(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
aten::adaptive_avg_poolnd
"""
return _adaptive_pool(
prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
return _adaptive_pool(prog, 'avg', inputs, outputs, attrs, name=name)
def AdaptiveMaxPool(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs):
def AdaptiveMaxPool(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
aten::adaptive_max_poolnd
"""
return _adaptive_pool(
prog, 'max', inputs, outputs, attrs, value_infos, name=name)
return _adaptive_pool(prog, 'max', inputs, outputs, attrs, name=name)
def AveragePool(prog,
......@@ -734,9 +709,9 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
var_output = _make_var_name(val_output)
# interpretation
dtype = attrs['to']
if not isinstance(dtype, np.dtype):
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype] # required
dtype = attrs['to'] # required
if not isinstance(dtype, np.dtype): # additional: possible np.dtype
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
output_dtype = _dtype_or_none(value_infos, val_output)
if output_dtype:
assert dtype == output_dtype, 'dtype of to unmatches output'
......@@ -818,15 +793,16 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
# dtype = np.dtype('float32') # force to float32
# dtype = np.dtype('float32') # HINT: force to float32
shape = attrs.get('shape', None) # additional, maybe var_name
if shape is None:
shape = _shape_or_none(value_infos, val_output)
if shape is None:
shape = list(value.shape)
_logger.warning(
'shape of %s not inferred, using value as 1-D tensor may lead to fails',
val_output)
'in (Constant -> %s): '
'shape of %s not inferred, '
'using value as 1-D tensor may lead to fails', outputs, val_output)
# generation
if value.size == 1: # scalar
......@@ -855,18 +831,27 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
# I/O
val_input, = inputs
val_shape, = inputs
val_output, = outputs
is_const_shape = 'const_value' in value_infos[val_input]
if is_const_shape:
shape = _make_var_name(val_input)
else:
shape = value_infos[val_input]['get_weight']()
shape = _const_weight_or_none(value_infos, val_shape)
if shape is None:
shape = _shape_or_none(value_infos, val_output)
assert shape is not None, (
'given shape is neither const value nor deductible from output, '
'this is not supported')
dtype = attrs['value'].dtype
attrs = attrs.copy()
attrs.update(dict(shape=shape, dtype=dtype)) # pass var_name
Constant(prog, [], outputs, attrs, value_infos)
prog.Op(
'',
'Constant',
[],
outputs, # val
attrs,
value_infos,
)
def Conv(prog,
......@@ -903,13 +888,11 @@ def Conv(prog,
num_out_channels = _shape(value_infos, val_w)[0] # OI...
fluid_op = 'conv{}d'.format(convnd)
num_groups = attrs.get('group', 1) # optional
strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
if val_x_padded:
val_x = val_x_padded
dilations = attrs.get('dilations', [1] * convnd) # optional
num_groups = attrs.get('group', 1) # optional
pads = attrs.get('pads', [0] * (convnd * 2)) # optional
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
......@@ -1014,13 +997,11 @@ def ConvTranspose(prog,
num_out_channels = _shape(value_infos, val_w)[1] # IO...
fluid_op = 'conv{}d_transpose'.format(convnd)
num_groups = attrs.get('group', 1) # optional
strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
if val_x_padded:
val_x = val_x_padded
dilations = attrs.get('dilations', [1] * convnd) # optional
num_groups = attrs.get('group', 1) # optional
pads = attrs.get('pads', [0] * (convnd * 2)) # optional
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
......@@ -1090,7 +1071,7 @@ def ConvTranspose(prog,
prog.VarDesc(var_y)
# should not appears
# should not appear
#def Dropout(
# prog, inputs, outputs, value_infos,
# *args, **kwargs):
......@@ -1154,10 +1135,16 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
else:
val_beta = name + '_beta' # explicit variable
val_vm = name + '_vm' # explicit variable
vm_dtype = _dtype_or_none(value_infos, val_c)
if vm_dtype is None:
vm_dtype = np.dtype('float32')
beta = np.dtype(vm_dtype).type(beta)
if beta.is_integer():
vm_dtype = _dtype_or_none(value_infos, val_c)
if vm_dtype is None:
vm_dtype = np.dtype('float32')
_logger.warning(
'in %s(%s -> Gemm -> %s): '
'beta seems to be an interger, '
'however dtype can not be inferred, '
'still use float32', name, inputs, outputs)
beta = np.dtype(vm_dtype).type(beta)
prog.Op(
'',
'Constant',
......@@ -1429,13 +1416,15 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
var_reshaped = _make_var_name(val_reshaped)
# interpretation
fluid_op = 'reshape'
is_const_shape = 'const_value' in value_infos[val_shape]
var_shape = _make_var_name(val_shape) # for code
if is_const_shape:
shape = value_infos[val_shape]['const_value'] # for desc
else:
shape = value_infos[val_shape]['get_weight']() # for desc
shape = _const_weight_or_none(value_infos, val_shape)
is_const_shape = shape and 'const_value' in value_infos[val_shape]
if shape is None:
shape = _shape_or_none(value_infos, var_reshaped)
assert shape is not None, (
'given shape is neither const value nor deductible from output, '
'this is not supported')
fluid_op = 'reshape'
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
......@@ -1457,7 +1446,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Cast',
[var_shape],
[var_shape_int32], # var
dict(to=np.dtype('int32')),
dict(to=np.dtype('int32')), # use np.dtype
value_infos=value_infos,
name=(name + '_cast'),
)
......@@ -1593,26 +1582,25 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
var_output = _make_var_name(val_output)
# interpretation
repeats = _const_weight_or_none(value_infos, val_repeats)
assert repeats is not None, 'only const repeats is supported'
fluid_op = 'expand'
is_const_repeats = 'const_value' in value_infos[val_repeats]
if is_const_repeats:
code_repeats = _make_var_name(val_repeats) # for code
repeats = value_infos[val_repeats]['const_value'] # for desc
else:
repeats = value_infos[val_input]['get_weight']() # for desc
code_repeats = repeats # for code
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
prog.Code('{} = layers.{}({}'
', expand_times={}'
'{})'.format(
'{})'
' # {} = {}'.format(
var_output,
fluid_op,
var_input,
# attrs
code_repeats,
repeats,
name_attr,
# comment
_make_var_name(val_repeats),
repeats,
))
prog.VarDesc(var_output)
prog.OpDesc(
......
......@@ -6,11 +6,12 @@ Created on Fri Mar 22 12:17:19 2019
@author: Macrobull
"""
# import importlib, logging, os, sys
import importlib
import logging
import os
import sys
import importlib, logging, os, sys
#import importlib
#import logging
#import os
#import sys
def _flatten_dict(obj, out=None):
......
......@@ -8,9 +8,9 @@ Created on Sun Feb 24 20:44:43 2019
from __future__ import division
# import logging, os
import logging
import os
import logging, os
#import logging
#import os
import numpy as np
logger = logging.getLogger(__name__)
......@@ -215,10 +215,6 @@ class Program(object):
var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
# REMOVEIT: WORKAROUND: Netron: null.tensor error
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(dummy_dtype) # required
if value_info and 'dtype' in value_info:
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(value_info['dtype']) # required
......@@ -230,6 +226,9 @@ class Program(object):
not persistable)
if remove_batch:
tensor_desc.dims[0] = -1
else: # REMOVEIT: WORKAROUND: Netron: null.tensor error
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(dummy_dtype) # required
self.var_descs.append(var_desc)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册