提交 826481c4 编写于 作者: M Macrobull

add type shape inference

上级 7c3e9379
...@@ -21,7 +21,7 @@ def make_var_name(name): ...@@ -21,7 +21,7 @@ def make_var_name(name):
return '_' return '_'
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' \\|/:': # for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
......
...@@ -24,7 +24,7 @@ def make_var_name(name): ...@@ -24,7 +24,7 @@ def make_var_name(name):
return '_' return '_'
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' \\|/:': # for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
......
...@@ -311,7 +311,7 @@ resnet100_arcface() ...@@ -311,7 +311,7 @@ resnet100_arcface()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid -o /tmp/export/ "$fn_model" -y python -m onnx2fluid $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar"/*/ for pb_dir in "$bn_tar"/*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
......
...@@ -95,6 +95,14 @@ parser.add_argument( ...@@ -95,6 +95,14 @@ parser.add_argument(
default=1e-2, default=1e-2,
help='assertion relative tolerance for validation', help='assertion relative tolerance for validation',
) )
parser.add_argument(
'--infer_inputs',
'-i',
nargs='?',
default=None,
const='',
help='perform type-shape inference with given input names and re-save model',
)
args = parser.parse_args() args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s' logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
......
...@@ -60,19 +60,28 @@ def main(**kwargs): ...@@ -60,19 +60,28 @@ def main(**kwargs):
# validate # validate
passed = True passed = True
golden_data_filename = kwargs.pop('test_data', '') golden_data_filename = kwargs.pop('test_data', '')
if golden_data_filename: infer_inputs = kwargs.pop('infer_inputs', None)
if golden_data_filename or infer_inputs:
from .validation import validate from .validation import validate
save_inference_model = infer_inputs is not None
inference_input_names = infer_inputs.split(
',') if infer_inputs else None
logger.info('starting validation on desc ...') logger.info('starting validation on desc ...')
passed &= validate(shutil.os.path.join(save_dir, '__model__'), passed &= validate(shutil.os.path.join(save_dir, '__model__'),
golden_data_filename, **kwargs) golden_data_filename=golden_data_filename,
save_inference_model=save_inference_model,
inference_input_names=inference_input_names,
**kwargs)
logger.info('starting validation on code ...') logger.info('starting validation on code ...')
# this re-generate desc proto with Python code when debug on # this re-generate desc proto with Python code when debug on
passed &= validate(shutil.os.path.join(save_dir, model_basename), passed &= validate(shutil.os.path.join(save_dir, model_basename),
golden_data_filename, golden_data_filename=golden_data_filename,
model_func_name=model_func_name, model_func_name=model_func_name,
save_inference_model=debug, save_inference_model=save_inference_model,
inference_input_names=inference_input_names,
**kwargs) **kwargs)
if not passed: if not passed:
......
...@@ -141,23 +141,22 @@ def convert(onnx_model_filename, ...@@ -141,23 +141,22 @@ def convert(onnx_model_filename,
logger.info('%d ops in, %d ops out', len(onnx_graph.node), logger.info('%d ops in, %d ops out', len(onnx_graph.node),
len(fluid_program.op_descs)) len(fluid_program.op_descs))
# shape-inference # type-shape inference
for name, value_info in graph_value_infos.items(): for name, value_info in graph_value_infos.items():
var_name = make_var_name(name) var_name = make_var_name(name)
fluid_program.VarTypeInfo(var_name, value_info, fluid_program.VarTypeShapeInfo(var_name, value_info,
remove_batch=False) # shape-infer only remove_batch=False) # shape-infer only
bad_var_names = [] bad_var_names = []
for var_name, var_desc in fluid_program.var_descs.items(): for var_name, var_desc in fluid_program.var_descs.items():
if not var_desc.type.lod_tensor.HasField('tensor'): if not var_desc.type.lod_tensor.HasField('tensor'):
bad_var_names.append(var_name) bad_var_names.append(var_name)
if len(bad_var_names) > 0: if len(bad_var_names) > 0:
logger.warning('type info not infered for var %s ...', logger.warning('type-shape not infered for var %s ...',
', '.join(bad_var_names[:5])) ', '.join(bad_var_names[:5]))
logger.warning('this causes little problem for PaddlePaddle, ' logger.warning('this causes little problem for PaddlePaddle, '
'but Paddle Mobile may not infer correctly') 'but Paddle Mobile may not infer correctly')
logger.warning( logger.warning('please consider running onnx2fluid.validation with -i '
'please consider adding option -d to invoke PaddlePaddle shape-inference' 'to invoke PaddlePaddle type-shape inference')
)
# weight writer # weight writer
for name, weight in graph_weights(onnx_graph): for name, weight in graph_weights(onnx_graph):
...@@ -233,13 +232,9 @@ def convert(onnx_model_filename, ...@@ -233,13 +232,9 @@ def convert(onnx_model_filename,
logger.info('conversion finished') logger.info('conversion finished')
if __name__ == '__main__': def main():
del convert
import argparse import argparse
from onnx2fluid.conversion import convert
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='onnx2fluid.convert', description='onnx2fluid.convert',
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
...@@ -310,3 +305,11 @@ if __name__ == '__main__': ...@@ -310,3 +305,11 @@ if __name__ == '__main__':
onnx_opset_pedantic=pedantic, onnx_opset_pedantic=pedantic,
onnx_skip_version_conversion=skip_version_conversion, onnx_skip_version_conversion=skip_version_conversion,
debug=debug) debug=debug)
if __name__ == '__main__':
del convert
from onnx2fluid.conversion import convert
main()
...@@ -210,7 +210,7 @@ def node_iter(nodes, indices=None): ...@@ -210,7 +210,7 @@ def node_iter(nodes, indices=None):
if name == '': if name == '':
name = 'op_' + str(index) name = 'op_' + str(index)
else: # make_op_name else: # make_op_name
for s in ' \\|/:': # for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if domain == '': if domain == '':
domain = DEFAULT_OP_DOMAIN domain = DEFAULT_OP_DOMAIN
......
...@@ -153,7 +153,7 @@ def _make_var_name(name): ...@@ -153,7 +153,7 @@ def _make_var_name(name):
return '_' return '_'
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' \\|/:': # for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
...@@ -191,14 +191,24 @@ def _const_weight_or_none(value_infos, val_name): ...@@ -191,14 +191,24 @@ def _const_weight_or_none(value_infos, val_name):
return None return None
value_info = value_infos[val_name] value_info = value_infos[val_name]
const_value = value_info.get('const_value', None) const_value = value_info.get('const_value', None)
if const_value: if const_value is not None:
return const_value return const_value
get_weight_func = value_info.get('get_weight', None) get_weight_func = value_info.get('get_weight', None)
if get_weight_func: if get_weight_func is not None:
return get_weight_func() return get_weight_func()
return None return None
def _check_embeddable(value_infos, *val_names):
keyword = 'get_weight'
for val_name in val_names:
if keyword not in value_infos[val_name]:
_logger.warning('parameter %s not embeddable for some ops',
val_name)
return False
return True
def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs): def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
info = DEFAULT_OP_MAPPING[op_type] info = DEFAULT_OP_MAPPING[op_type]
info.extend(DEFAULT_OP_MAPPING_VALUES[len(info):]) info.extend(DEFAULT_OP_MAPPING_VALUES[len(info):])
...@@ -391,9 +401,9 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -391,9 +401,9 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
input_shape = _shape_or_none(value_infos, val_x) input_shape = _shape_or_none(value_infos, val_x)
output_shape = _shape_or_none(value_infos, val_y) output_shape = _shape_or_none(value_infos, val_y)
assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # NC... assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # NC...
if input_shape: if input_shape is not None:
poolnd = len(input_shape) - 2 # NC... poolnd = len(input_shape) - 2 # NC...
elif output_shape: elif output_shape is not None:
poolnd = len(output_shape) - 2 # NC... poolnd = len(output_shape) - 2 # NC...
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported' assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
...@@ -568,7 +578,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): ...@@ -568,7 +578,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
1] == 1, 'only scale on (NC)HW supported' 1] == 1, 'only scale on (NC)HW supported'
assert scales[2] == scales[ assert scales[2] == scales[
3], 'only aspect-ratio-invariant scale supported' 3], 'only aspect-ratio-invariant scale supported'
scale = scales[2] if scales else None scale = None if scales is None else scales[2]
# try input shape # try input shape
if scale is None: if scale is None:
assert out_shape_, 'neither scales nor output shape is available' assert out_shape_, 'neither scales nor output shape is available'
...@@ -704,16 +714,19 @@ def BatchNormalization(prog, ...@@ -704,16 +714,19 @@ def BatchNormalization(prog,
momentum = attrs.get('momentum', .9) # optional momentum = attrs.get('momentum', .9) # optional
epsilon = attrs.get('epsilon', 1e-5) # optional epsilon = attrs.get('epsilon', 1e-5) # optional
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
embed_params = _check_embeddable(value_infos, val_scale, val_b,
val_mean, val_var)
if embed_params: if embed_params:
assert name != '' assert name != ''
var_scale = name + '.w_0' var_scale = name + '.w_0'
var_b = name + '.b_0' var_b = name + '.b_0'
var_mean = name + '.w_1' var_mean = name + '.w_1'
var_var = name + '.w_2' var_var = name + '.w_2'
value_infos[val_scale].setdefault('embeded_as', []).append(var_scale) value_infos[val_scale]['embeded_as'].append(var_scale)
value_infos[val_b].setdefault('embeded_as', []).append(var_b) value_infos[val_b]['embeded_as'].append(var_b)
value_infos[val_mean].setdefault('embeded_as', []).append(var_mean) value_infos[val_mean]['embeded_as'].append(var_mean)
value_infos[val_var].setdefault('embeded_as', []).append(var_var) value_infos[val_var]['embeded_as'].append(var_var)
param_attr = '' param_attr = ''
else: else:
var_scale = _make_var_name(val_scale) var_scale = _make_var_name(val_scale)
...@@ -774,7 +787,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -774,7 +787,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
if not isinstance(dtype, _np.dtype): # additional: possible np.dtype if not isinstance(dtype, _np.dtype): # additional: possible np.dtype
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype] dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
output_dtype = _dtype_or_none(value_infos, val_output) output_dtype = _dtype_or_none(value_infos, val_output)
if output_dtype: if output_dtype is not None:
assert dtype == output_dtype, 'dtype of to unmatches output' assert dtype == output_dtype, 'dtype of to unmatches output'
fluid_op = 'cast' fluid_op = 'cast'
...@@ -843,7 +856,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -843,7 +856,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
""" """
# I/O # I/O
assert len(inputs) == 0 assert len(inputs) == 0, 'constant op accept no inputs'
val_output, = outputs val_output, = outputs
var_output = _make_var_name(val_output) var_output = _make_var_name(val_output)
...@@ -851,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -851,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
value = attrs['value'] # required value = attrs['value'] # required
dtype = _np.dtype(value.dtype) dtype = _np.dtype(value.dtype)
output_dtype = _dtype_or_none(value_infos, val_output) output_dtype = _dtype_or_none(value_infos, val_output)
if output_dtype: if output_dtype is not None:
assert dtype == output_dtype, 'tensor dtype unmatches storage dtype' assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
...@@ -970,13 +983,16 @@ def Conv(prog, ...@@ -970,13 +983,16 @@ def Conv(prog,
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos) paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x) var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
or _check_embeddable(value_infos, val_b))
if embed_params: if embed_params:
assert name != '' assert name != ''
var_w = name + '.w_0' var_w = name + '.w_0'
value_infos[val_w].setdefault('embeded_as', []).append(var_w) value_infos[val_w]['embeded_as'].append(var_w)
if has_bias: if has_bias:
var_b = name + '.b_0' var_b = name + '.b_0'
value_infos[val_b].setdefault('embeded_as', []).append(var_b) value_infos[val_b]['embeded_as'].append(var_b)
param_attr = '' param_attr = ''
else: else:
param_attr = ', bias_attr=False' param_attr = ', bias_attr=False'
...@@ -1080,13 +1096,16 @@ def ConvTranspose(prog, ...@@ -1080,13 +1096,16 @@ def ConvTranspose(prog,
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos) paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x) var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
or _check_embeddable(value_infos, val_b))
if embed_params: if embed_params:
assert name != '' assert name != ''
var_w = name + '.w_0' var_w = name + '.w_0'
value_infos[val_w].setdefault('embeded_as', []).append(var_w) value_infos[val_w]['embeded_as'].append(var_w)
if has_bias: if has_bias:
var_b = name + '.b_0' var_b = name + '.b_0'
value_infos[val_b].setdefault('embeded_as', []).append(var_b) value_infos[val_b]['embeded_as'].append(var_b)
param_attr = '' param_attr = ''
else: else:
param_attr = ', bias_attr=False' param_attr = ', bias_attr=False'
...@@ -1330,9 +1349,9 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -1330,9 +1349,9 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
assume_pad2d = False assume_pad2d = False
if len(pads) == 4: if len(pads) == 4:
assume_pad2d |= mode != 'constant' assume_pad2d |= mode != 'constant'
if data_shape: if data_shape is not None:
assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW
if output_shape: if output_shape is not None:
assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW
od_attrs = {'pad_value': value} od_attrs = {'pad_value': value}
if assume_pad2d: if assume_pad2d:
...@@ -1404,10 +1423,12 @@ def PRelu(prog, ...@@ -1404,10 +1423,12 @@ def PRelu(prog,
mode = 'element' mode = 'element'
fluid_op = 'prelu' fluid_op = 'prelu'
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
embed_params = _check_embeddable(value_infos, val_slope)
if embed_params: if embed_params:
assert name != '' assert name != ''
var_slope = name + '.w_0' var_slope = name + '.w_0'
value_infos[val_slope].setdefault('embeded_as', []).append(var_slope) value_infos[val_slope]['embeded_as'].append(var_slope)
param_attr = '' param_attr = ''
else: else:
var_slope = _make_var_name(val_slope) var_slope = _make_var_name(val_slope)
...@@ -1474,6 +1495,8 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1474,6 +1495,8 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
# generation # generation
val_shape_int32 = val_shape + '_int32' # explicit variable
var_shape_int32 = _make_var_name(val_shape_int32)
prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape)) prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape))
if is_const_shape: if is_const_shape:
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
...@@ -1487,8 +1510,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1487,8 +1510,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
name_attr, name_attr,
)) ))
else: else:
val_shape_int32 = val_shape + '_int32' # explicit variable
var_shape_int32 = _make_var_name(val_shape_int32)
prog.Op( prog.Op(
'', '',
'Cast', 'Cast',
...@@ -1514,14 +1535,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1514,14 +1535,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
var_xshape = name + '.xshape' # dummy output var_xshape = name + '.xshape' # dummy output
prog.VarDesc(var_reshaped) prog.VarDesc(var_reshaped)
prog.VarDesc(var_xshape) prog.VarDesc(var_xshape)
if is_const_shape:
prog.OpDesc(
fluid_op,
([var_data], 'X'),
([var_reshaped, var_xshape], 'Out', 'XShape'),
{'shape': shape},
)
else:
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
([var_data, var_shape_int32], 'X', 'Shape'), ([var_data, var_shape_int32], 'X', 'Shape'),
...@@ -1595,7 +1608,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1595,7 +1608,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
starts = attrs['starts'] # required starts = attrs['starts'] # required
ends = attrs['ends'] # required ends = attrs['ends'] # required
shape = _shape_or_none(value_infos, val_data) shape = _shape_or_none(value_infos, val_data)
if shape: if shape is not None:
# ndims = len(shape) # ndims = len(shape)
# for idx, value in enumerate(axes): # for idx, value in enumerate(axes):
# if value > ONNX_INT_MAX // 2: # if value > ONNX_INT_MAX // 2:
......
...@@ -25,7 +25,8 @@ def ensure_tuple(obj): ...@@ -25,7 +25,8 @@ def ensure_tuple(obj):
def flatten_list(obj, out=None): def flatten_list(obj, out=None):
assert isinstance(obj, list) assert isinstance(obj, list), 'list type required'
if out is None: if out is None:
out = type(obj)() out = type(obj)()
for item in obj: for item in obj:
...@@ -42,7 +43,7 @@ def export_data(state_dict, prefix=''): ...@@ -42,7 +43,7 @@ def export_data(state_dict, prefix=''):
""" """
def str_(obj): def str_(obj):
if isinstance(obj, (tuple, list)): if isinstance(obj, (tuple, list, set)):
return str(obj)[1:-1].replace(' ', '') return str(obj)[1:-1].replace(' ', '')
return str(obj) return str(obj)
......
...@@ -10,14 +10,15 @@ import importlib, logging, os, sys ...@@ -10,14 +10,15 @@ import importlib, logging, os, sys
def flatten_dict(obj, out=None): def flatten_dict(obj, out=None):
assert isinstance(obj, dict) assert isinstance(obj, dict), 'dict type required'
if out is None: if out is None:
out = type(obj)() out = type(obj)()
for key, value in obj.items(): for key, value in obj.items():
if isinstance(value, dict): if isinstance(value, dict):
flatten_dict(value, out) flatten_dict(value, out)
else: else:
assert key not in out assert key not in out, 'key conflicted'
out[key] = value out[key] = value
return out return out
...@@ -29,11 +30,12 @@ def ensure_list(obj): ...@@ -29,11 +30,12 @@ def ensure_list(obj):
def validate(fluid_model_filename, def validate(fluid_model_filename,
golden_data_filename, golden_data_filename='',
model_func_name='inference',
atol=1e-3, atol=1e-3,
rtol=1e-3, rtol=1e-3,
model_func_name='inference',
save_inference_model=False, save_inference_model=False,
inference_input_names=None,
**kwargs): **kwargs):
""" """
inference the converted Paddle fluid model, validate with given golden data inference the converted Paddle fluid model, validate with given golden data
...@@ -86,6 +88,7 @@ def validate(fluid_model_filename, ...@@ -86,6 +88,7 @@ def validate(fluid_model_filename,
raise ValueError('unsupported Paddle fluid model filename') raise ValueError('unsupported Paddle fluid model filename')
# load data # load data
if golden_data_filename:
logger.info('using golden data %s', golden_data_filename) logger.info('using golden data %s', golden_data_filename)
if golden_data_filename.endswith('.npz'): if golden_data_filename.endswith('.npz'):
test_data = np.load(golden_data_filename, encoding='bytes') test_data = np.load(golden_data_filename, encoding='bytes')
...@@ -97,13 +100,38 @@ def validate(fluid_model_filename, ...@@ -97,13 +100,38 @@ def validate(fluid_model_filename,
output_data = test_data['outputs'] output_data = test_data['outputs']
input_data = flatten_dict(input_data) input_data = flatten_dict(input_data)
output_data = flatten_dict(output_data) output_data = flatten_dict(output_data)
input_names = input_data.keys()
logger.info('found %d I/O golden data, starting test ...', logger.info('found %d I/O golden data, starting test ...',
len(input_data) + len(output_data)) len(input_data) + len(output_data))
else:
assert inference_input_names, 'input names required for type-shape inference'
input_names = inference_input_names
logger.info('using input names: %s', ', '.join(input_names))
# type-shape inference and re-save
if save_inference_model:
for block in prog.blocks:
block_desc = block.desc
for idx_op in range(block_desc.op_size()):
op_desc = block_desc.op(idx_op)
if op_desc.type() in ('feed', 'fetch'):
continue
op_desc.infer_var_type(block_desc)
op_desc.infer_shape(block_desc)
for var_name, var in block.vars.items():
var_desc = var.desc
if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR:
continue
# WORKAROUND: dirty way to give dtype to partially-inferred vars
# which could not be cleared!
try:
var.to_string(True)
except ValueError:
var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32)
# DEBUG: reload test for Python code
if basename.endswith('.py') and save_inference_model:
fluid.io.save_inference_model(fluid_model_dir, fluid.io.save_inference_model(fluid_model_dir,
input_data.keys(), input_names,
var_outs, var_outs,
exe, exe,
main_program=prog, main_program=prog,
...@@ -112,8 +140,12 @@ def validate(fluid_model_filename, ...@@ -112,8 +140,12 @@ def validate(fluid_model_filename,
fluid.io.load_inference_model(fluid_model_dir, exe) fluid.io.load_inference_model(fluid_model_dir, exe)
logger.info('model re-load passed') logger.info('model re-load passed')
if not golden_data_filename:
return True
# execute # execute
outputs = exe.run(prog, feed=input_data, fetch_list=out_names) outputs = exe.run(prog, feed=input_data,
fetch_list=out_names) # out_names can be vars
logger.info('execution passed') logger.info('execution passed')
# validate # validate
...@@ -134,11 +166,10 @@ def validate(fluid_model_filename, ...@@ -134,11 +166,10 @@ def validate(fluid_model_filename,
logger.info('accuracy passed') logger.info('accuracy passed')
else: else:
logger.info('accuracy not passed') logger.info('accuracy not passed')
return passed return passed
if __name__ == '__main__': def main():
import argparse import argparse
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -160,6 +191,7 @@ if __name__ == '__main__': ...@@ -160,6 +191,7 @@ if __name__ == '__main__':
'--test_data', '--test_data',
'-t', '-t',
type=str, type=str,
default='',
help='I/O golden data for validation, e.g. test.npy, test.npz', help='I/O golden data for validation, e.g. test.npy, test.npz',
) )
parser.add_argument( parser.add_argument(
...@@ -175,19 +207,36 @@ if __name__ == '__main__': ...@@ -175,19 +207,36 @@ if __name__ == '__main__':
default=1e-2, default=1e-2,
help='assertion relative tolerance for validation', help='assertion relative tolerance for validation',
) )
parser.add_argument(
'--infer_inputs',
'-i',
nargs='?',
default=None,
const='',
help=
'perform type-shape inference with given input names and re-save model',
)
args = parser.parse_args() args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s' logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level = logging.DEBUG if args.debug else logging.INFO logging_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(format=logging_format, level=logging_level) logging.basicConfig(format=logging_format, level=logging_level)
debug = args.debug # debug = args.debug
fluid_model_filename = args.model[0] fluid_model_filename = args.model[0]
golden_data_filename = args.test_data golden_data_filename = args.test_data
atol, rtol = args.atol, args.rtol atol, rtol = args.atol, args.rtol
save_inference_model = args.infer_inputs is not None
inference_input_names = args.infer_inputs.split(
',') if args.infer_inputs else None
validate(fluid_model_filename, validate(fluid_model_filename,
golden_data_filename, golden_data_filename=golden_data_filename,
atol=atol, atol=atol,
rtol=rtol, rtol=rtol,
save_inference_model=debug) save_inference_model=save_inference_model,
inference_input_names=inference_input_names)
if __name__ == '__main__':
main()
...@@ -44,6 +44,8 @@ def irepr(obj, to='_'): ...@@ -44,6 +44,8 @@ def irepr(obj, to='_'):
def flatten_list(obj, out=None): def flatten_list(obj, out=None):
assert isinstance(obj, list), 'list type required'
if out is None: if out is None:
out = type(obj)() out = type(obj)()
for item in obj: for item in obj:
...@@ -59,9 +61,9 @@ def make_attr_name(name): ...@@ -59,9 +61,9 @@ def make_attr_name(name):
make a valid code name for ParamAttr make a valid code name for ParamAttr
""" """
if name == '': assert name != '', 'name should not be empty'
raise ValueError('name should not be empty')
for s in ' \\|/:': # for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if not name.startswith('_'): if not name.startswith('_'):
name = '_' + name name = '_' + name
...@@ -195,11 +197,11 @@ class Program(object): ...@@ -195,11 +197,11 @@ class Program(object):
desc = framework_pb2.OpDesc() desc = framework_pb2.OpDesc()
desc.type = op_type desc.type = op_type
if input_val_keys is not None: if input_val_keys:
desc.inputs.extend(self.OpDescVars(*input_val_keys)) desc.inputs.extend(self.OpDescVars(*input_val_keys))
if output_val_keys is not None: if output_val_keys:
desc.outputs.extend(self.OpDescVars(*output_val_keys)) desc.outputs.extend(self.OpDescVars(*output_val_keys))
if attrs is not None: if attrs:
desc.attrs.extend(self.OpDescAttrs(attrs)) desc.attrs.extend(self.OpDescAttrs(attrs))
self.op_descs.append(desc) self.op_descs.append(desc)
return desc return desc
...@@ -220,8 +222,11 @@ class Program(object): ...@@ -220,8 +222,11 @@ class Program(object):
var_desc.persistable = persistable var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
self.var_descs[var_name] = var_desc self.var_descs[var_name] = var_desc
if value_info: if value_info:
self.VarTypeInfo(var_name, value_info, remove_batch=remove_batch) self.VarTypeShapeInfo(var_name,
value_info,
remove_batch=remove_batch)
def Op(self, domain, op_type, *args, **kwargs): def Op(self, domain, op_type, *args, **kwargs):
""" """
...@@ -255,7 +260,7 @@ class Program(object): ...@@ -255,7 +260,7 @@ class Program(object):
else: else:
self.code_mutable = code_mutable self.code_mutable = code_mutable
def VarTypeInfo(self, var_name, value_info, remove_batch=None): def VarTypeShapeInfo(self, var_name, value_info, remove_batch=None):
""" """
set value_info for var set value_info for var
""" """
...@@ -429,8 +434,7 @@ class Writer(object): ...@@ -429,8 +434,7 @@ class Writer(object):
write single weight in fluid desc write single weight in fluid desc
""" """
if not isinstance(weight, np.ndarray): assert isinstance(weight, np.ndarray), 'weight is not an ndarray'
raise TypeError('weight is not an ndarray')
tensor_desc = framework_pb2.VarType.TensorDesc() tensor_desc = framework_pb2.VarType.TensorDesc()
tensor_desc.data_type = Program.Dtype(weight.dtype) tensor_desc.data_type = Program.Dtype(weight.dtype)
...@@ -452,8 +456,7 @@ class Writer(object): ...@@ -452,8 +456,7 @@ class Writer(object):
""" """
for name, weight in weights.items(): for name, weight in weights.items():
if not isinstance(weights, dict): assert isinstance(weights, dict), 'dict type weights required'
raise TypeError('dict type weights required')
var_name = make_var_name(name) var_name = make_var_name(name)
filename = os.path.join(save_dir, var_name) filename = os.path.join(save_dir, var_name)
......
...@@ -54,8 +54,8 @@ zip_safe = True ...@@ -54,8 +54,8 @@ zip_safe = True
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
onnx2fluid = onnx2fluid.__main__ onnx2fluid = onnx2fluid.__main__
onnx2fluid_convert = onnx2fluid.conversion onnx2fluid_convert = onnx2fluid.conversion:main
onnx2fluid_validate = onnx2fluid.validation onnx2fluid_validate = onnx2fluid.validation:main
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下 # 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 仅支持文件,不支持目录,但可以使用通配 # 仅支持文件,不支持目录,但可以使用通配
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册