diff --git a/onnx2fluid/onnx2fluid/cmdline.py b/onnx2fluid/onnx2fluid/cmdline.py index ba8b22bcf5293e70fa642a3076523a4e016c2037..4294490224a4762301d102a2b6b45970cbb5fcf5 100644 --- a/onnx2fluid/onnx2fluid/cmdline.py +++ b/onnx2fluid/onnx2fluid/cmdline.py @@ -33,7 +33,7 @@ def main(**kwargs): from .conversion import convert logger = logging.getLogger('onnx2fluid') - debug = kwargs.get('debug', False) + # debug = kwargs.get('debug', False) # prepare arguments filename = kwargs.pop('model')[0] @@ -65,8 +65,7 @@ def main(**kwargs): from .validation import validate save_inference_model = infer_inputs is not None - inference_input_names = infer_inputs.split( - ',') if infer_inputs else None + inference_input_names = infer_inputs and infer_inputs.split(',') logger.info('starting validation on desc ...') passed &= validate(shutil.os.path.join(save_dir, '__model__'), @@ -85,7 +84,7 @@ def main(**kwargs): **kwargs) if not passed: - logger.error('validation failed, exit') + logger.fatal('validation failed, exit') return # create zip file diff --git a/onnx2fluid/onnx2fluid/conversion.py b/onnx2fluid/onnx2fluid/conversion.py index 0c440a3955020a6524a61c9a3e6ed75efd6534bf..b566936430de98103ca4671f125e620fae489b5f 100644 --- a/onnx2fluid/onnx2fluid/conversion.py +++ b/onnx2fluid/onnx2fluid/conversion.py @@ -34,15 +34,12 @@ def convert(onnx_model_filename, from onnx.checker import ValidationError from onnx.checker import check_model - from onnx.utils import polish_model from onnx.version_converter import convert_version from .onnx_utils import DEFAULT_OP_DOMAIN from .onnx_utils import graph_ops, graph_weights from .onnx_utils import inferred_model_value_info - from .onnx_utils import optimize_model_skip_op_for_inference - from .onnx_utils import optimize_model_strip_initializer - from .onnx_utils import optimize_model_cast, optimize_model_slice + from .onnx_utils import polish_model from .writer import Program, Writer from .writer import make_var_name @@ -56,14 +53,12 @@ def convert(onnx_model_filename, logger.info('checking model ...') check_model(onnx_model) if onnx_opset_version is None: # WORKAROUND: RuntimeError: No Adapter For OP - logger.debug('assumed opset version: %d', - DEFAULT_ONNX_OPSET_VERSION) logger.warning( 'opset conversion skipped for onnx_opset_pedantic is OFF') + logger.info('assumed opset version: %d', DEFAULT_ONNX_OPSET_VERSION) else: - logger.debug('using opset version: %d', onnx_opset_version) + logger.info('using opset version: %d', onnx_opset_version) onnx_model = convert_version(onnx_model, onnx_opset_version) - onnx_model = polish_model(onnx_model) except ValidationError as e: if onnx_opset_pedantic: raise e @@ -75,10 +70,7 @@ def convert(onnx_model_filename, # onnx model optimization logger.info('model has %d ops', len(onnx_model.graph.node)) logger.info('optimizing model ...') - onnx_model = optimize_model_skip_op_for_inference(onnx_model) - onnx_model = optimize_model_strip_initializer(onnx_model) - onnx_model = optimize_model_cast(onnx_model) - onnx_model = optimize_model_slice(onnx_model) + onnx_model = polish_model(onnx_model) # prepare filesystem shutil.rmtree(save_dir, ignore_errors=True) @@ -87,9 +79,8 @@ def convert(onnx_model_filename, # DEBUG: if debug: - model = onnx.shape_inference.infer_shapes(onnx_model) debug_model_filename, _ = shutil.os.path.splitext(onnx_model_filename) - onnx.save(model, debug_model_filename + '.optimized_and_inffered.onnx') + onnx.save(onnx_model, debug_model_filename + '.polished.onnx') # I/O instances onnx_graph = onnx_model.graph @@ -141,11 +132,11 @@ def convert(onnx_model_filename, logger.info('%d ops in, %d ops out', len(onnx_graph.node), len(fluid_program.op_descs)) - # type-shape inference + # type-shape info copy for name, value_info in graph_value_infos.items(): var_name = make_var_name(name) fluid_program.VarTypeShapeInfo(var_name, value_info, - remove_batch=False) # shape-infer only + remove_batch=False) # bad_var_names = [] for var_name, var_desc in fluid_program.var_descs.items(): if not var_desc.type.lod_tensor.HasField('tensor'): @@ -155,8 +146,8 @@ def convert(onnx_model_filename, ', '.join(bad_var_names[:5])) logger.warning('this causes little problem for PaddlePaddle, ' 'but Paddle Mobile may not infer correctly') - logger.warning('please consider running onnx2fluid.validation with -i ' - 'to invoke PaddlePaddle type-shape inference') + logger.warning('please consider running validation with -i ' + 'to invoke type-shape inference in PaddlePaddle') # weight writer for name, weight in graph_weights(onnx_graph): diff --git a/onnx2fluid/onnx2fluid/onnx_utils.py b/onnx2fluid/onnx2fluid/onnx_utils.py index 19e0c73dfbbb4b2188e84edd963131f620613498..b6f6dc39f2ec367f575ff4f8321cd57e297ce329 100644 --- a/onnx2fluid/onnx2fluid/onnx_utils.py +++ b/onnx2fluid/onnx2fluid/onnx_utils.py @@ -11,9 +11,11 @@ from __future__ import division import logging import numpy as np import onnx +import onnx.optimizer as optimizer from collections import OrderedDict as Dict # as default dict -from onnx.helper import get_attribute_value, make_attribute +from onnx.checker import check_model +from onnx.helper import get_attribute_value, make_attribute, strip_doc_string from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE from onnx.numpy_helper import to_array from onnx.shape_inference import infer_shapes @@ -23,14 +25,16 @@ logger = logging.getLogger(__name__) __all__ = [ 'print_pb_structure', 'build_value_refs', + 'tensor_dtype', + 'tensor_shape', 'node_attrs', 'node_topo', 'node_iter', - 'tensor_dtype', - 'tensor_shape', 'graph_ops', 'graph_weights', 'inferred_model_value_info', + 'polish_model', + 'polish_and_save', 'optimize_model_skip_op_for_inference', 'optimize_model_strip_initializer', 'optimize_model_cast', @@ -110,7 +114,7 @@ def tensor_shape(tensor): get ONNX tensor shape """ - return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim] + return tuple([dim.dim_value for dim in tensor.type.tensor_type.shape.dim]) def node_attrs(node): @@ -195,10 +199,7 @@ def node_iter(nodes, indices=None): generator for ONNX node graph with given indices """ - if indices is None: - indices = range(len(nodes)) - - for index in indices: + for index in indices or range(len(nodes)): node = nodes[index] name = node.name domain = node.domain @@ -306,6 +307,48 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs): return processed +def polish_model(model, extras=True): + """ + polish_model enhanced for inference + """ + + check_model(model) + strip_doc_string(model) + passes = optimizer.get_available_passes() + passes = list(filter(lambda name: not name.startswith('split_'), passes)) # + logger.debug('builtin optimizations to perform in ONNX:\n\t%s', passes) + model = optimizer.optimize(model, passes=passes) + if extras: + for optimize in ( + optimize_model_skip_op_for_inference, + optimize_model_strip_initializer, + optimize_model_cast, + optimize_model_slice, + ): + model = optimize(model) + model = infer_shapes(model) + check_model(model) + return model + + +def polish_and_save(model_filename, + suffix='.polished', + save_filename=None, + *args, + **kwargs): + """ + run polish_model and save + """ + + model = onnx.load(model_filename) + model = polish_model(model, *args, **kwargs) + save_filename = save_filename or model_filename.replace( + '.onnx', suffix + '.onnx') + onnx.save(model, save_filename) + logger.info('polished model saved to: %s', save_filename) + return save_filename + + def optimize_model_skip_op_for_inference(model, op_list=None): """ skip ops can be bypassed for inference @@ -326,7 +369,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None): if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''): continue op_type = node.op_type - if not (op_type in op_list): + if op_type not in op_list: continue if op_type in ('Dropout', ): @@ -590,22 +633,16 @@ if __name__ == '__main__': level=logging.DEBUG, ) - from onnx.checker import check_model - from onnx.utils import polish_model from onnx.version_converter import convert_version - model = onnx.load('../examples/t1.onnx') + model = onnx.load('/tmp/export.onnx') print_pb_structure(model, loop_iterative=False) check_model(model) model = convert_version(model, 9) - model = optimize_model_skip_op_for_inference(model) - model = optimize_model_strip_initializer(model) - model = optimize_model_cast(model) - model = optimize_model_slice(model) model = polish_model(model) - onnx.save(model, '/tmp/optimized.onnx') + onnx.save(model, '/tmp/export.optimized.onnx') graph = model.graph value_info = inferred_model_value_info(model) @@ -617,23 +654,23 @@ if __name__ == '__main__': logger.info('ops:') for name, domain, op_type, _, _, attrs in graph_ops(graph, topo='forward'): - logger.info('%s %s::%s: %s', name, domain, op_type, attrs) + logger.info('- \t%s %s::%s: %s', name, domain, op_type, attrs) logger.info('weights:') for name, array in graph_weights(graph): weights.append(name) - logger.info('%s: %s', name, array.shape) + logger.info('- \t%s: %s', name, array.shape) logger.info('inputs:') external_inputs = [] for name in inputs: if name not in weights: external_inputs.append(name) - logger.info('%s: %s', name, value_info[name]['shape']) + logger.info('- \t%s: %s', name, value_info[name]['shape']) logger.info('outputs:') external_outputs = [] for name in outputs: if name not in weights: external_outputs.append(name) - logger.info('%s: %s', name, value_info[name]['shape']) + logger.info('- \t%s: %s', name, value_info[name]['shape']) diff --git a/onnx2fluid/onnx2fluid/symbolic.py b/onnx2fluid/onnx2fluid/symbolic.py index 16733169990283714499d42c46cac3834fd55a4a..a1cee35642c0ba108e60e097b940c68c773335f9 100644 --- a/onnx2fluid/onnx2fluid/symbolic.py +++ b/onnx2fluid/onnx2fluid/symbolic.py @@ -203,8 +203,7 @@ def _check_embeddable(value_infos, *val_names): keyword = 'get_weight' for val_name in val_names: if keyword not in value_infos[val_name]: - _logger.warning('parameter %s not embeddable for some ops', - val_name) + _logger.warning('parameter %s not embeddable', val_name) return False return True @@ -240,9 +239,9 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs): fluid_attrs = default_attrs.copy() fluid_attrs.update(mapped_attrs) # as new attrs - val_inps = inputs if input_perm is None else map(lambda i: inputs[i], + val_inps = inputs if input_perm is None else map(inputs.__getitem__, input_perm) - val_outs = outputs if output_perm is None else map(lambda i: outputs[i], + val_outs = outputs if output_perm is None else map(outputs.__getitem__, output_perm) var_inps = [_make_var_name(val) for val in val_inps] var_outs = [_make_var_name(val) for val in val_outs] @@ -578,7 +577,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): 1] == 1, 'only scale on (NC)HW supported' assert scales[2] == scales[ 3], 'only aspect-ratio-invariant scale supported' - scale = None if scales is None else scales[2] + scale = scales and scales[2] # try input shape if scale is None: assert out_shape_, 'neither scales nor output shape is available' @@ -717,6 +716,10 @@ def BatchNormalization(prog, if embed_params: embed_params = _check_embeddable(value_infos, val_scale, val_b, val_mean, val_var) + if not embed_params and name: + _logger.warning('for op %s(%s -> BatchNormalization -> %s)', name, + inputs, outputs) + _logger.warning('broken Python code will be generated') if embed_params: assert name != '' var_scale = name + '.w_0' @@ -875,7 +878,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): if shape is None: shape = list(value.shape) _logger.warning( - 'in (Constant -> %s): ' + 'in op (Constant -> %s): ' 'attribute "shape" of %s not inferred, ' 'using value as 1-D tensor may lead to fails', outputs, val_output) @@ -986,6 +989,10 @@ def Conv(prog, if embed_params: embed_params = (_check_embeddable(value_infos, val_w) and not has_bias or _check_embeddable(value_infos, val_b)) + if not embed_params and name: + _logger.warning('for op %s(%s -> Conv -> %s)', name, inputs, + outputs) + _logger.warning('broken Python code will be generated') if embed_params: assert name != '' var_w = name + '.w_0' @@ -1099,6 +1106,10 @@ def ConvTranspose(prog, if embed_params: embed_params = (_check_embeddable(value_infos, val_w) and not has_bias or _check_embeddable(value_infos, val_b)) + if not embed_params and name: + _logger.warning('for op %s(%s -> ConvTranspose -> %s)', name, + inputs, outputs) + _logger.warning('broken Python code will be generated') if embed_params: assert name != '' var_w = name + '.w_0' @@ -1167,23 +1178,6 @@ def ConvTranspose(prog, prog.VarDesc(var_y) -# should not appear -#def Dropout( -# prog, inputs, outputs, value_infos, -# *args, **kwargs): -# """ -# onnx::Dropout-7:9 -# """ -# -# val_data, = inputs -# val_output, = outputs[:1] -# -# _assign(prog, -# dict([(val_output, val_data)]), -# value_infos, -# ) - - def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): """ onnx::Gemm-9: @@ -1236,7 +1230,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): if vm_dtype is None: vm_dtype = _np.dtype('float32') _logger.warning( - 'in %s(%s -> Gemm -> %s): ' + 'in op %s(%s -> Gemm -> %s): ' 'attribute "beta" seems to be an interger, ' 'however dtype can not be inferred, ' 'still use float32', name, inputs, outputs) @@ -1425,6 +1419,10 @@ def PRelu(prog, name_attr = ', name={}'.format(repr(name)) if name else '' if embed_params: embed_params = _check_embeddable(value_infos, val_slope) + if not embed_params and name: + _logger.warning('for op %s(%s -> PRelu -> %s)', name, inputs, + outputs) + _logger.warning('broken Python code will be generated') if embed_params: assert name != '' var_slope = name + '.w_0' @@ -1487,7 +1485,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): if shape is None: shape = [1, -1] # who knows _logger.warning( - 'in %s(%s -> Reshape -> %s): ' + 'in op %s(%s -> Reshape -> %s): ' 'input "shape" not inferred, use [1, -1] as dummy value, ' 'the behavior of Paddle fluid maybe undefined', name, inputs, outputs) diff --git a/onnx2fluid/onnx2fluid/torch_export_helper.py b/onnx2fluid/onnx2fluid/torch_export_helper.py index 7a0fd6031433e989fafc142a40e9bad46df9f41f..39f034317c76b3372c3a969df33dc20fe264ddaf 100644 --- a/onnx2fluid/onnx2fluid/torch_export_helper.py +++ b/onnx2fluid/onnx2fluid/torch_export_helper.py @@ -9,22 +9,50 @@ Created on Fri Mar 22 11:22:46 2019 import numpy as np import torch -from collections import OrderedDict as Dict +from collections import OrderedDict +from typing import ( + TypeVar, + Any, + Generic, + Iterable, + List, + Mapping, + Optional, + Sequence, + Text, + Tuple, + Union, +) +__all__ = [ + 'export_data', + 'export_onnx_with_validation', +] -def ensure_list(obj): +my_dict = OrderedDict + +KT = TypeVar('KT') +VT = TypeVar('VT') + + +class MyDict(my_dict, Generic[KT, VT]): + pass + + +def ensure_list(obj: Union[object, Sequence[object]]) -> List[object]: if isinstance(obj, (list, tuple, set)): return list(obj) return [obj] -def ensure_tuple(obj): +def ensure_tuple(obj: Union[object, Sequence[object]]) -> Tuple[object, ...]: if isinstance(obj, (tuple, list, set)): return tuple(obj) return (obj, ) -def flatten_list(obj, out=None): +def flatten_list(obj: List[Union[object, List[object]]], + out: Optional[List[object]] = None) -> List[object]: assert isinstance(obj, list), 'list type required' if out is None: @@ -37,21 +65,21 @@ def flatten_list(obj, out=None): return out -def export_data(state_dict, prefix=''): +def export_data(state_dict: Mapping[Text, Any], prefix: Text = '') -> None: """ export binary data with meta text for raw C++ inference engines """ - def str_(obj): + def str_(obj: object) -> Text: if isinstance(obj, (tuple, list, set)): return str(obj)[1:-1].replace(' ', '') return str(obj) prefix_ = prefix + ('_' if prefix else '') - fp = open('{}.txt'.format(prefix if prefix else 'meta'), 'w') + fp = open('{}.txt'.format(prefix or 'meta'), 'w') for key, value in state_dict.items(): data = None - if torch and torch.is_tensor(value): + if torch.is_tensor(value): data = value.data.cpu().numpy() elif isinstance(value, np.ndarray): data = value @@ -64,30 +92,33 @@ def export_data(state_dict, prefix=''): fp.close() -def export_onnx_with_validation(model, - inputs, - export_basepath, - input_names=None, - output_names=None, - use_npz=True, - *args, - **kwargs): +def export_onnx_with_validation( + model: torch.nn.Module, + inputs: Sequence[Union[torch.Tensor, Sequence[object]]], + export_basepath: Text, + input_names: Optional[List[Text]] = None, + output_names: Optional[List[Text]] = None, + use_npz: bool = True, + *args, + **kwargs) -> Sequence[Union[torch.Tensor, Sequence[object]]]: """ export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file """ is_tuple_or_list = lambda x: isinstance(x, (tuple, list)) - def tensors_to_arrays(tensors): + def tensors_to_arrays(tensors: Union[torch.Tensor, Iterable[ + Union[torch.Tensor, Iterable[Any]]]], ) -> List[np.ndarray]: if torch.is_tensor(tensors): return tensors.data.cpu().numpy() - arrays = [] - for tensor in tensors: - arrays.append(tensors_to_arrays(tensor)) - return arrays - - def zip_dict(keys, values): - ret = Dict() + return list(map(tensors_to_arrays, tensors)) + + def zip_dict( + keys: Union[Iterable[Any], None], + values: Sequence[Union[Any, Sequence[Any]]], + ) -> MyDict[Text, Union[object, MyDict[Text, object]]]: + keys = keys or range(len(values)) + ret = my_dict() for idx, (key, value) in enumerate(zip(keys, values)): is_key_list = is_tuple_or_list(key) is_value_list = is_tuple_or_list(value) @@ -102,19 +133,48 @@ def export_onnx_with_validation(model, outputs = torch.onnx.export(model, torch_inputs, export_basepath + '.onnx', - input_names=flatten_list(input_names), - output_names=flatten_list(output_names), + input_names=(None if input_names is None else + flatten_list(input_names)), + output_names=(None if output_names is None else + flatten_list(output_names)), *args, **kwargs) if outputs is None: # WORKAROUND: for torch.onnx - outputs = model(*inputs) + training = kwargs.get('training', False) + with torch.onnx.set_training(model, training): + outputs = model(*inputs) torch_outputs = ensure_tuple(outputs) inputs = zip_dict(input_names, tensors_to_arrays(torch_inputs)) outputs = zip_dict(output_names, tensors_to_arrays(torch_outputs)) if use_npz: - np.savez(export_basepath + '.npz', inputs=inputs, outputs=outputs) + np.savez( + export_basepath + '.npz', + inputs=inputs, + outputs=outputs, + ) else: np.save(export_basepath + '.npy', - np.array(Dict(inputs=inputs, outputs=outputs))) + np.asarray(my_dict(inputs=inputs, outputs=outputs)), + allow_pickle=True) + return torch_outputs + + +if __name__ == '__main__': + from torchvision.models import resnet18 as net + + model = net() + xb = torch.rand((1, 3, 224, 224)) + export_onnx_with_validation( + model, + (xb, ), + '/tmp/export', + input_names=[ + 'image', + ], + output_names=[ + 'prob', + ], + use_npz=True, + ) diff --git a/onnx2fluid/onnx2fluid/validation.py b/onnx2fluid/onnx2fluid/validation.py index 223e1116fdd652c525a6f713e9d813eef7478c1c..22e78d9e604568931d9077802d37682ae4b7c6ae 100644 --- a/onnx2fluid/onnx2fluid/validation.py +++ b/onnx2fluid/onnx2fluid/validation.py @@ -8,6 +8,13 @@ Created on Fri Mar 22 12:17:19 2019 import importlib, logging, os, sys +logger = logging.getLogger(__name__) + +__all__ = [ + 'fluid_prog_shape_infer', + 'validate', +] + def flatten_dict(obj, out=None): assert isinstance(obj, dict), 'dict type required' @@ -29,6 +36,42 @@ def ensure_list(obj): return [obj] +def fluid_prog_shape_infer(prog): + """ + additional type-shape inference for fluid program + """ + + import paddle.fluid as fluid + + assert isinstance(prog, fluid.framework.Program) + + logger.info('performing type-shape inference ...') + for block in prog.blocks: + block_desc = block.desc + + for idx_op in range(block_desc.op_size()): + op_desc = block_desc.op(idx_op) + if op_desc.type() in ('feed', 'fetch'): + continue + + op_desc.infer_var_type(block_desc) + op_desc.infer_shape(block_desc) + + for var_name, var in block.vars.items(): + var_desc = var.desc + if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR: + continue + + # WORKAROUND: dirty way to give dtype to partial-infered vars + # which could not be cleared! + try: + var.to_string(True) + except ValueError: + var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32) + logger.debug('dtype of var %s not inferred, float32 assumed', + var_name) + + def validate(fluid_model_filename, golden_data_filename='', atol=1e-3, @@ -53,12 +96,12 @@ def validate(fluid_model_filename, # load model fluid_model_dir, basename = os.path.split(fluid_model_filename) if basename == '__model__': # is desc program - logger.debug('using desc file %s', basename) + logger.info('using desc file %s', basename) prog, _, var_outs = fluid.io.load_inference_model(fluid_model_dir, exe) out_names = var_outs # HINT: pass var if fetch ops already created logger.info('model load passed') elif basename.endswith('.py'): # is Python code - logger.debug('using code file %s', basename) + logger.info('using code file %s', basename) module_name, _ = os.path.splitext(basename) sys_path = sys.path.copy() sys.path.append(fluid_model_dir) @@ -91,18 +134,28 @@ def validate(fluid_model_filename, if golden_data_filename: logger.info('using golden data %s', golden_data_filename) if golden_data_filename.endswith('.npz'): - test_data = np.load(golden_data_filename, encoding='bytes') + test_data = np.load( + golden_data_filename, + encoding='bytes', + allow_pickle=True, + ) input_data = test_data['inputs'].tolist() output_data = test_data['outputs'].tolist() else: - test_data = np.load(golden_data_filename, encoding='bytes').tolist() + test_data = np.load( + golden_data_filename, + encoding='bytes', + allow_pickle=True, + ).tolist() input_data = test_data['inputs'] output_data = test_data['outputs'] + input_data = flatten_dict(input_data) output_data = flatten_dict(output_data) input_names = input_data.keys() - logger.info('found %d I/O golden data, starting test ...', - len(input_data) + len(output_data)) + output_names = output_data.keys() + logger.info('with %d inputs and %d outputs', len(input_data), + len(output_data)) else: assert inference_input_names, 'input names required for type-shape inference' @@ -111,25 +164,7 @@ def validate(fluid_model_filename, # type-shape inference and re-save if save_inference_model: - for block in prog.blocks: - block_desc = block.desc - for idx_op in range(block_desc.op_size()): - op_desc = block_desc.op(idx_op) - if op_desc.type() in ('feed', 'fetch'): - continue - op_desc.infer_var_type(block_desc) - op_desc.infer_shape(block_desc) - for var_name, var in block.vars.items(): - var_desc = var.desc - if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR: - continue - # WORKAROUND: dirty way to give dtype to partial-infered vars - # which could not be cleared! - try: - var.to_string(True) - except ValueError: - var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32) - + fluid_prog_shape_infer(prog) fluid.io.save_inference_model(fluid_model_dir, input_names, var_outs, @@ -151,7 +186,7 @@ def validate(fluid_model_filename, # validate passed = True for (name, truth), output in zip(output_data.items(), outputs): - logger.info('testing output {} ...'.format(name)) + logger.info('testing on output {} ...'.format(name)) try: np.testing.assert_allclose(output, truth, @@ -162,10 +197,7 @@ def validate(fluid_model_filename, except AssertionError as e: passed = False logger.error('failed: %s\n', e) - if passed: - logger.info('accuracy passed') - else: - logger.info('accuracy not passed') + logger.info('accuracy %spassed', '' if passed else 'not ') return passed