diff --git a/onnx2fluid/examples/convert_data_npz.py b/onnx2fluid/examples/convert_data_npz.py index 82486de51001e355add8e43b43d312adf18b19e8..80c921a2e77eda5b2cc7d0a082560c4a39c7a8f9 100644 --- a/onnx2fluid/examples/convert_data_npz.py +++ b/onnx2fluid/examples/convert_data_npz.py @@ -14,14 +14,14 @@ from collections import OrderedDict as Dict def make_var_name(name): """ - make a valid variable name in Python code - """ + make a valid variable name in Python code + """ if name == '': return '_' if name[0].isdigit(): return 'var_' + name - for s in ' \\|/:': # + for s in ' \\|/:-': # name = name.replace(s, '_') if name.startswith('_'): name = 'var' + name diff --git a/onnx2fluid/examples/convert_data_pb.py b/onnx2fluid/examples/convert_data_pb.py index 16484a8cb27f3a20e42666e967c07ad58ba87fc0..f07b73c0b5d205b909a6315a0fb1c7e0213d4ad9 100644 --- a/onnx2fluid/examples/convert_data_pb.py +++ b/onnx2fluid/examples/convert_data_pb.py @@ -17,14 +17,14 @@ from glob import glob def make_var_name(name): """ - make a valid variable name in Python code - """ + make a valid variable name in Python code + """ if name == '': return '_' if name[0].isdigit(): return 'var_' + name - for s in ' \\|/:': # + for s in ' \\|/:-': # name = name.replace(s, '_') if name.startswith('_'): name = 'var' + name diff --git a/onnx2fluid/examples/onnx_model_zoo.sh b/onnx2fluid/examples/onnx_model_zoo.sh index b544b6e479f03199f2abbce0caee88b9079c8413..ca7fa9bf2c648d93f95c7cc5461fa677dc1c6a0e 100755 --- a/onnx2fluid/examples/onnx_model_zoo.sh +++ b/onnx2fluid/examples/onnx_model_zoo.sh @@ -311,7 +311,7 @@ resnet100_arcface() echo "extracting ..." tar xf "$fn_tar" - python -m onnx2fluid -o /tmp/export/ "$fn_model" -y + python -m onnx2fluid $convert_flags "$fn_model" -y for pb_dir in "$bn_tar"/*/ do echo "converting $pb_dir ..." diff --git a/onnx2fluid/onnx2fluid/__main__.py b/onnx2fluid/onnx2fluid/__main__.py index 17d5b4307a787a025c35b6b56683aba2fd8679bb..f09f63e331c83a5e6719f2e6396640cb33dba015 100644 --- a/onnx2fluid/onnx2fluid/__main__.py +++ b/onnx2fluid/onnx2fluid/__main__.py @@ -95,6 +95,14 @@ parser.add_argument( default=1e-2, help='assertion relative tolerance for validation', ) +parser.add_argument( + '--infer_inputs', + '-i', + nargs='?', + default=None, + const='', + help='perform type-shape inference with given input names and re-save model', +) args = parser.parse_args() logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s' diff --git a/onnx2fluid/onnx2fluid/cmdline.py b/onnx2fluid/onnx2fluid/cmdline.py index 801388a45db3d3586cb9c029d099319e1ff2c092..ba8b22bcf5293e70fa642a3076523a4e016c2037 100644 --- a/onnx2fluid/onnx2fluid/cmdline.py +++ b/onnx2fluid/onnx2fluid/cmdline.py @@ -60,19 +60,28 @@ def main(**kwargs): # validate passed = True golden_data_filename = kwargs.pop('test_data', '') - if golden_data_filename: + infer_inputs = kwargs.pop('infer_inputs', None) + if golden_data_filename or infer_inputs: from .validation import validate + save_inference_model = infer_inputs is not None + inference_input_names = infer_inputs.split( + ',') if infer_inputs else None + logger.info('starting validation on desc ...') passed &= validate(shutil.os.path.join(save_dir, '__model__'), - golden_data_filename, **kwargs) + golden_data_filename=golden_data_filename, + save_inference_model=save_inference_model, + inference_input_names=inference_input_names, + **kwargs) logger.info('starting validation on code ...') # this re-generate desc proto with Python code when debug on passed &= validate(shutil.os.path.join(save_dir, model_basename), - golden_data_filename, + golden_data_filename=golden_data_filename, model_func_name=model_func_name, - save_inference_model=debug, + save_inference_model=save_inference_model, + inference_input_names=inference_input_names, **kwargs) if not passed: diff --git a/onnx2fluid/onnx2fluid/conversion.py b/onnx2fluid/onnx2fluid/conversion.py index 33957a6b0543c401d55539d84c96f77f26a2e449..0c440a3955020a6524a61c9a3e6ed75efd6534bf 100644 --- a/onnx2fluid/onnx2fluid/conversion.py +++ b/onnx2fluid/onnx2fluid/conversion.py @@ -27,8 +27,8 @@ def convert(onnx_model_filename, debug=False, **kwargs): """ - convert an ONNX model to Paddle fluid Python code and desc pb - """ + convert an ONNX model to Paddle fluid Python code and desc pb + """ import onnx @@ -141,23 +141,22 @@ def convert(onnx_model_filename, logger.info('%d ops in, %d ops out', len(onnx_graph.node), len(fluid_program.op_descs)) - # shape-inference + # type-shape inference for name, value_info in graph_value_infos.items(): var_name = make_var_name(name) - fluid_program.VarTypeInfo(var_name, value_info, - remove_batch=False) # shape-infer only + fluid_program.VarTypeShapeInfo(var_name, value_info, + remove_batch=False) # shape-infer only bad_var_names = [] for var_name, var_desc in fluid_program.var_descs.items(): if not var_desc.type.lod_tensor.HasField('tensor'): bad_var_names.append(var_name) if len(bad_var_names) > 0: - logger.warning('type info not infered for var %s ...', + logger.warning('type-shape not infered for var %s ...', ', '.join(bad_var_names[:5])) logger.warning('this causes little problem for PaddlePaddle, ' 'but Paddle Mobile may not infer correctly') - logger.warning( - 'please consider adding option -d to invoke PaddlePaddle shape-inference' - ) + logger.warning('please consider running onnx2fluid.validation with -i ' + 'to invoke PaddlePaddle type-shape inference') # weight writer for name, weight in graph_weights(onnx_graph): @@ -233,13 +232,9 @@ def convert(onnx_model_filename, logger.info('conversion finished') -if __name__ == '__main__': - del convert - +def main(): import argparse - from onnx2fluid.conversion import convert - parser = argparse.ArgumentParser( description='onnx2fluid.convert', formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -310,3 +305,11 @@ if __name__ == '__main__': onnx_opset_pedantic=pedantic, onnx_skip_version_conversion=skip_version_conversion, debug=debug) + + +if __name__ == '__main__': + del convert + + from onnx2fluid.conversion import convert + + main() diff --git a/onnx2fluid/onnx2fluid/onnx_utils.py b/onnx2fluid/onnx2fluid/onnx_utils.py index 67d9fe854ca5b192a4917b7a8b218e63e119eb0c..19e0c73dfbbb4b2188e84edd963131f620613498 100644 --- a/onnx2fluid/onnx2fluid/onnx_utils.py +++ b/onnx2fluid/onnx2fluid/onnx_utils.py @@ -44,8 +44,8 @@ DEFAULT_OP_DOMAIN = 'ai.onnx' def print_pb_structure(message, loop_iterative=False, depth=0): """ - print pb fields in its structure - """ + print pb fields in its structure + """ if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'): for field in message.DESCRIPTOR.fields: @@ -65,8 +65,8 @@ def print_pb_structure(message, loop_iterative=False, depth=0): def build_value_refs(nodes): """ - build op reference of inputs and outputs - """ + build op reference of inputs and outputs + """ input_refs = Dict() output_refs = Dict() @@ -80,8 +80,8 @@ def build_value_refs(nodes): def get_attribute_value2(attr): """ - get_attribute_value enhanced - """ + get_attribute_value enhanced + """ if attr.type == onnx.AttributeProto.TENSOR: dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type]) @@ -99,24 +99,24 @@ def get_attribute_value2(attr): def tensor_dtype(tensor): """ - get ONNX tensor in np.dtype - """ + get ONNX tensor in np.dtype + """ return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type] def tensor_shape(tensor): """ - get ONNX tensor shape - """ + get ONNX tensor shape + """ return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim] def node_attrs(node): """ - convert ONNX node attributes to dict - """ + convert ONNX node attributes to dict + """ return {attr.name: get_attribute_value2(attr) for attr in node.attribute} # dict @@ -124,8 +124,8 @@ def node_attrs(node): def node_topo(nodes, topo='default'): """ - build indices with given topology to an ONNX node graph - """ + build indices with given topology to an ONNX node graph + """ if topo == 'default': return list(range(len(nodes))) @@ -192,8 +192,8 @@ def node_topo(nodes, topo='default'): def node_iter(nodes, indices=None): """ - generator for ONNX node graph with given indices - """ + generator for ONNX node graph with given indices + """ if indices is None: indices = range(len(nodes)) @@ -210,7 +210,7 @@ def node_iter(nodes, indices=None): if name == '': name = 'op_' + str(index) else: # make_op_name - for s in ' \\|/:': # + for s in ' \\|/:-': # name = name.replace(s, '_') if domain == '': domain = DEFAULT_OP_DOMAIN @@ -220,8 +220,8 @@ def node_iter(nodes, indices=None): def graph_ops(graph, topo='default'): """ - generator for ONNX node graph with given topology - """ + generator for ONNX node graph with given topology + """ if not isinstance(graph, onnx.GraphProto): logger.error('graph is not a GraphProto instance') @@ -232,8 +232,8 @@ def graph_ops(graph, topo='default'): def graph_weights(graph): """ - generator for weights of an ONNX model - """ + generator for weights of an ONNX model + """ if not isinstance(graph, onnx.GraphProto): logger.error('graph is not a GraphProto instance') @@ -247,8 +247,8 @@ def graph_weights(graph): def inferred_model_value_info(model): """ - collect value/type info for an ONNX model - """ + collect value/type info for an ONNX model + """ model = infer_shapes(model) graph = model.graph @@ -278,8 +278,8 @@ def inferred_model_value_info(model): def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs): """ - skip nodes between src_output_name -> dst_input_name and connect this pair - """ + skip nodes between src_output_name -> dst_input_name and connect this pair + """ processed = 0 for next_idx in input_refs[src_output_name]: @@ -293,8 +293,8 @@ def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs): def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs): """ - skip nodes between dst_output_name -> src_input_name and connect this pair - """ + skip nodes between dst_output_name -> src_input_name and connect this pair + """ processed = 0 for prev_idx in output_refs[src_input_name]: @@ -308,8 +308,8 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs): def optimize_model_skip_op_for_inference(model, op_list=None): """ - skip ops can be bypassed for inference - """ + skip ops can be bypassed for inference + """ if op_list is None: op_list = ('Dropout', 'Identity') @@ -369,8 +369,8 @@ def optimize_model_skip_op_for_inference(model, op_list=None): def optimize_model_strip_initializer(model, keep_input_only=True): """ - strip weights for inference - """ + strip weights for inference + """ nodes = model.graph.node input_refs, output_refs = build_value_refs(nodes) @@ -410,8 +410,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True): def optimize_model_cast(model): """ - strip cascade and unecessary onnx::Cast-9: - """ + strip cascade and unecessary onnx::Cast-9: + """ nodes = model.graph.node input_refs, output_refs = build_value_refs(nodes) @@ -467,8 +467,8 @@ def optimize_model_cast(model): def optimize_model_slice(model): """ - strip cascade and unecessary onnx::Slice-1:9 - """ + strip cascade and unecessary onnx::Slice-1:9 + """ nodes = model.graph.node input_refs, output_refs = build_value_refs(nodes) diff --git a/onnx2fluid/onnx2fluid/symbolic.py b/onnx2fluid/onnx2fluid/symbolic.py index 814103854dee72e3cd1d56538f03d1857babff6c..16733169990283714499d42c46cac3834fd55a4a 100644 --- a/onnx2fluid/onnx2fluid/symbolic.py +++ b/onnx2fluid/onnx2fluid/symbolic.py @@ -41,74 +41,74 @@ DEFAULT_OP_MAPPING_FIELD_VALUES['FILL_NAME_FIELD'] = True DEFAULT_OP_MAPPING_VALUES = list(DEFAULT_OP_MAPPING_FIELD_VALUES.values()) DEFAULT_OP_MAPPING = { - ## nil ops ## - 'RandomUniform': - ['uniform_random', [], ['Out'], dict(high='max', low='min'), - dict(), None, None, False], - 'RandomNormal': - ['gaussian_random', [], ['Out'], dict(scale='std'), - dict(), None, None, False], - ## unary ops ## - 'Abs': ['abs', ['X'], ['Out']], - 'ArgMax': ['argmax', ['X'], ['Out'], dict(keepdims='')], - 'ArgMin': ['argmin', ['X'], ['Out'], dict(keepdims='')], - 'Ceil': ['ceil', ['X'], ['Out']], - 'Clip': ['clip', ['X'], ['Out']], # attrs bypassed - 'Cos': ['cos', ['X'], ['Out']], - 'Elu': ['elu', ['X'], ['Out']], - 'Exp': ['exp', ['X'], ['Out']], - 'Flatten': ['flatten', ['X'], ['Out']], # attrs bypassed, FIXME: emit flatten2 - 'Floor': ['floor', ['X'], ['Out']], - 'Gather': ['gather', ['X'], ['Out'], dict(axis='')], - 'LeakyRelu': ['leaky_relu', ['X'], ['Out']], - 'Log': ['log', ['X'], ['Out']], - 'LRN': ['lrn', ['X'], ['Out', 'MidOut'], dict(size='n', bias='k')], # - 'Reciprocal': ['reciprocal', ['X'], ['Out']], - 'Relu': ['relu', ['X'], ['Out']], - 'Selu': ['selu', ['X'], ['Out'], dict(gamma='scale')], - 'Shape': ['shape', ['X'], ['Out']], # FIXME: out is int64 vs int32 - 'Shrink': ['softshrink', ['X'], ['Out'], dict(bias='', labmd='')], - 'Sigmoid': ['sigmoid', ['X'], ['Out']], - 'Sin': ['sin', ['X'], ['Out']], - 'Squeeze': ['squeeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit squeeze2 - 'Softplus': ['softplus', ['X'], ['Out']], - # FIXME: default axis = -1, reshape required before and after - 'Softmax': ['softmax', ['X'], ['Out'], dict(axis='')], - 'Softsign': ['softsign', ['X'], ['Out']], - 'Sqrt': ['sqrt', ['X'], ['Out']], - 'Tanh': ['tanh', ['X'], ['Out']], - 'ThresholdedRelu': ['thresholded_relu', ['X'], ['Out'], dict(alpha='threshold')], - #'Transpose': ['transpose', ['X'], ['Out']], - 'Unsqueeze': ['unsqueeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit unsqueeze2 - ## binary ops ## - 'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], - #'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')], - 'And': ['logical_and', ['X', 'Y'], ['Out']], - 'Div': ['elementwise_div', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], - 'Equal': ['equal', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False], - 'Greater': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), [1, 0], None, False], - 'Less': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False], - 'MatMul': ['matmul', ['X', 'Y'], ['Out']], # defaults excluded for transpose_x vs transpose_X - 'Max': ['elementwise_max', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], - 'Min': ['elementwise_min', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], - 'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], - 'Not': ['logical_not', ['X', 'Y'], ['Out']], - 'OneHot': # assuming values=[0, 1], axis=-1 and drop them - ['one_hot', ['Input', 'Depth'], ['Out'], dict(axis=''), dict(), - [0, 1], None, False], - 'Or': ['logical_or', ['X', 'Y'], ['Out']], - 'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], # TODO: pow for scalar exponent - 'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], - 'Xor': ['logical_xor', ['X', 'Y'], ['Out']], - # reduce ops - 'ReduceMax': ['reduce_max', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], - 'ReduceMean': ['reduce_mean', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], - 'ReduceMin': ['reduce_min', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], - 'ReduceProd': ['reduce_prod', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], - 'ReduceSum': ['reduce_sum', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], - # other ops - 'Scatter': ['scatter', ['X', 'Index', 'Updates'], ['Out']], - 'TopK': ['topk', ['X', 'K'], ['Out', 'Indices']], + ## nil ops ## + 'RandomUniform': + ['uniform_random', [], ['Out'], dict(high='max', low='min'), + dict(), None, None, False], + 'RandomNormal': + ['gaussian_random', [], ['Out'], dict(scale='std'), + dict(), None, None, False], + ## unary ops ## + 'Abs': ['abs', ['X'], ['Out']], + 'ArgMax': ['argmax', ['X'], ['Out'], dict(keepdims='')], + 'ArgMin': ['argmin', ['X'], ['Out'], dict(keepdims='')], + 'Ceil': ['ceil', ['X'], ['Out']], + 'Clip': ['clip', ['X'], ['Out']], # attrs bypassed + 'Cos': ['cos', ['X'], ['Out']], + 'Elu': ['elu', ['X'], ['Out']], + 'Exp': ['exp', ['X'], ['Out']], + 'Flatten': ['flatten', ['X'], ['Out']], # attrs bypassed, FIXME: emit flatten2 + 'Floor': ['floor', ['X'], ['Out']], + 'Gather': ['gather', ['X'], ['Out'], dict(axis='')], + 'LeakyRelu': ['leaky_relu', ['X'], ['Out']], + 'Log': ['log', ['X'], ['Out']], + 'LRN': ['lrn', ['X'], ['Out', 'MidOut'], dict(size='n', bias='k')], # + 'Reciprocal': ['reciprocal', ['X'], ['Out']], + 'Relu': ['relu', ['X'], ['Out']], + 'Selu': ['selu', ['X'], ['Out'], dict(gamma='scale')], + 'Shape': ['shape', ['X'], ['Out']], # FIXME: out is int64 vs int32 + 'Shrink': ['softshrink', ['X'], ['Out'], dict(bias='', labmd='')], + 'Sigmoid': ['sigmoid', ['X'], ['Out']], + 'Sin': ['sin', ['X'], ['Out']], + 'Squeeze': ['squeeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit squeeze2 + 'Softplus': ['softplus', ['X'], ['Out']], + # FIXME: default axis = -1, reshape required before and after + 'Softmax': ['softmax', ['X'], ['Out'], dict(axis='')], + 'Softsign': ['softsign', ['X'], ['Out']], + 'Sqrt': ['sqrt', ['X'], ['Out']], + 'Tanh': ['tanh', ['X'], ['Out']], + 'ThresholdedRelu': ['thresholded_relu', ['X'], ['Out'], dict(alpha='threshold')], + #'Transpose': ['transpose', ['X'], ['Out']], + 'Unsqueeze': ['unsqueeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit unsqueeze2 + ## binary ops ## + 'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], + #'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')], + 'And': ['logical_and', ['X', 'Y'], ['Out']], + 'Div': ['elementwise_div', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], + 'Equal': ['equal', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False], + 'Greater': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), [1, 0], None, False], + 'Less': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False], + 'MatMul': ['matmul', ['X', 'Y'], ['Out']], # defaults excluded for transpose_x vs transpose_X + 'Max': ['elementwise_max', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], + 'Min': ['elementwise_min', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], + 'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], + 'Not': ['logical_not', ['X', 'Y'], ['Out']], + 'OneHot': # assuming values=[0, 1], axis=-1 and drop them + ['one_hot', ['Input', 'Depth'], ['Out'], dict(axis=''), dict(), + [0, 1], None, False], + 'Or': ['logical_or', ['X', 'Y'], ['Out']], + 'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], # TODO: pow for scalar exponent + 'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], + 'Xor': ['logical_xor', ['X', 'Y'], ['Out']], + # reduce ops + 'ReduceMax': ['reduce_max', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], + 'ReduceMean': ['reduce_mean', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], + 'ReduceMin': ['reduce_min', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], + 'ReduceProd': ['reduce_prod', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], + 'ReduceSum': ['reduce_sum', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], + # other ops + 'Scatter': ['scatter', ['X', 'Index', 'Updates'], ['Out']], + 'TopK': ['topk', ['X', 'K'], ['Out', 'Indices']], } DEFAULT_IOA_CONSTRAINTS = { @@ -146,14 +146,14 @@ DEFAULT_IOA_CONSTRAINTS = { def _make_var_name(name): """ - make a valid variable name in Python code and in filesystem - """ + make a valid variable name in Python code and in filesystem + """ if name == '': return '_' if name[0].isdigit(): return 'var_' + name - for s in ' \\|/:': # + for s in ' \\|/:-': # name = name.replace(s, '_') if name.startswith('_'): name = 'var' + name @@ -191,14 +191,24 @@ def _const_weight_or_none(value_infos, val_name): return None value_info = value_infos[val_name] const_value = value_info.get('const_value', None) - if const_value: + if const_value is not None: return const_value get_weight_func = value_info.get('get_weight', None) - if get_weight_func: + if get_weight_func is not None: return get_weight_func() return None +def _check_embeddable(value_infos, *val_names): + keyword = 'get_weight' + for val_name in val_names: + if keyword not in value_infos[val_name]: + _logger.warning('parameter %s not embeddable for some ops', + val_name) + return False + return True + + def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs): info = DEFAULT_OP_MAPPING[op_type] info.extend(DEFAULT_OP_MAPPING_VALUES[len(info):]) @@ -391,9 +401,9 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): input_shape = _shape_or_none(value_infos, val_x) output_shape = _shape_or_none(value_infos, val_y) assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # NC... - if input_shape: + if input_shape is not None: poolnd = len(input_shape) - 2 # NC... - elif output_shape: + elif output_shape is not None: poolnd = len(output_shape) - 2 # NC... assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported' @@ -568,7 +578,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): 1] == 1, 'only scale on (NC)HW supported' assert scales[2] == scales[ 3], 'only aspect-ratio-invariant scale supported' - scale = scales[2] if scales else None + scale = None if scales is None else scales[2] # try input shape if scale is None: assert out_shape_, 'neither scales nor output shape is available' @@ -613,24 +623,24 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): def AdaptiveAveragePool(prog, inputs, outputs, attrs, *args, name='', **kwargs): """ - aten::adaptive_avg_poolnd - """ + aten::adaptive_avg_poolnd + """ return _adaptive_pool(prog, 'avg', inputs, outputs, attrs, name=name) def AdaptiveMaxPool(prog, inputs, outputs, attrs, *args, name='', **kwargs): """ - aten::adaptive_max_poolnd - """ + aten::adaptive_max_poolnd + """ return _adaptive_pool(prog, 'max', inputs, outputs, attrs, name=name) def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs): """ - aten::affine_grid - """ + aten::affine_grid + """ # I/O val_theta, = inputs @@ -672,8 +682,8 @@ def AveragePool(prog, *args, **kwargs): """ - onnx::AveragePool-10: - """ + onnx::AveragePool-10: + """ return _pool(prog, 'avg', inputs, outputs, attrs, value_infos, name=name) @@ -688,8 +698,8 @@ def BatchNormalization(prog, *args, **kwargs): """ - onnx::BatchNormalization-9: - """ + onnx::BatchNormalization-9: + """ # I/O val_x, val_scale, val_b, val_mean, val_var = inputs @@ -704,16 +714,19 @@ def BatchNormalization(prog, momentum = attrs.get('momentum', .9) # optional epsilon = attrs.get('epsilon', 1e-5) # optional name_attr = ', name={}'.format(repr(name)) if name else '' + if embed_params: + embed_params = _check_embeddable(value_infos, val_scale, val_b, + val_mean, val_var) if embed_params: assert name != '' var_scale = name + '.w_0' var_b = name + '.b_0' var_mean = name + '.w_1' var_var = name + '.w_2' - value_infos[val_scale].setdefault('embeded_as', []).append(var_scale) - value_infos[val_b].setdefault('embeded_as', []).append(var_b) - value_infos[val_mean].setdefault('embeded_as', []).append(var_mean) - value_infos[val_var].setdefault('embeded_as', []).append(var_var) + value_infos[val_scale]['embeded_as'].append(var_scale) + value_infos[val_b]['embeded_as'].append(var_b) + value_infos[val_mean]['embeded_as'].append(var_mean) + value_infos[val_var]['embeded_as'].append(var_var) param_attr = '' else: var_scale = _make_var_name(val_scale) @@ -760,8 +773,8 @@ def BatchNormalization(prog, def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): """ - onnx::Cast-9: - """ + onnx::Cast-9: + """ # I/O val_input, = inputs @@ -774,7 +787,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): if not isinstance(dtype, _np.dtype): # additional: possible np.dtype dtype = TENSOR_TYPE_TO_NP_TYPE[dtype] output_dtype = _dtype_or_none(value_infos, val_output) - if output_dtype: + if output_dtype is not None: assert dtype == output_dtype, 'dtype of to unmatches output' fluid_op = 'cast' @@ -804,8 +817,8 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs): """ - onnx::Concat-4: - """ + onnx::Concat-4: + """ # I/O val_concat_result, = outputs @@ -839,11 +852,11 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs): def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): """ - onnx::Constant-9: - """ + onnx::Constant-9: + """ # I/O - assert len(inputs) == 0 + assert len(inputs) == 0, 'constant op accept no inputs' val_output, = outputs var_output = _make_var_name(val_output) @@ -851,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): value = attrs['value'] # required dtype = _np.dtype(value.dtype) output_dtype = _dtype_or_none(value_infos, val_output) - if output_dtype: + if output_dtype is not None: assert dtype == output_dtype, 'tensor dtype unmatches storage dtype' @@ -900,8 +913,8 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): """ - onnx::ConstantOfShape-9: - """ + onnx::ConstantOfShape-9: + """ # I/O val_shape, = inputs @@ -939,8 +952,8 @@ def Conv(prog, *args, **kwargs): """ - onnx::Conv-1: - """ + onnx::Conv-1: + """ # I/O val_x, val_w = inputs[:2] @@ -970,13 +983,16 @@ def Conv(prog, paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos) var_x = _make_var_name(val_x) name_attr = ', name={}'.format(repr(name)) if name else '' + if embed_params: + embed_params = (_check_embeddable(value_infos, val_w) and not has_bias + or _check_embeddable(value_infos, val_b)) if embed_params: assert name != '' var_w = name + '.w_0' - value_infos[val_w].setdefault('embeded_as', []).append(var_w) + value_infos[val_w]['embeded_as'].append(var_w) if has_bias: var_b = name + '.b_0' - value_infos[val_b].setdefault('embeded_as', []).append(var_b) + value_infos[val_b]['embeded_as'].append(var_b) param_attr = '' else: param_attr = ', bias_attr=False' @@ -1046,8 +1062,8 @@ def ConvTranspose(prog, *args, **kwargs): """ - onnx::ConvTranspose-1: - """ + onnx::ConvTranspose-1: + """ # I/O val_x, val_w = inputs[:2] @@ -1080,13 +1096,16 @@ def ConvTranspose(prog, paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos) var_x = _make_var_name(val_x) name_attr = ', name={}'.format(repr(name)) if name else '' + if embed_params: + embed_params = (_check_embeddable(value_infos, val_w) and not has_bias + or _check_embeddable(value_infos, val_b)) if embed_params: assert name != '' var_w = name + '.w_0' - value_infos[val_w].setdefault('embeded_as', []).append(var_w) + value_infos[val_w]['embeded_as'].append(var_w) if has_bias: var_b = name + '.b_0' - value_infos[val_b].setdefault('embeded_as', []).append(var_b) + value_infos[val_b]['embeded_as'].append(var_b) param_attr = '' else: param_attr = ', bias_attr=False' @@ -1167,8 +1186,8 @@ def ConvTranspose(prog, def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): """ - onnx::Gemm-9: - """ + onnx::Gemm-9: + """ # due to fluid fc don't support transposed weight, we use matmul + ew_add val_a, val_b, val_c = inputs @@ -1259,8 +1278,8 @@ def GlobalAveragePool(prog, *args, **kwargs): """ - onnx::GlobalAveragePool-1: - """ + onnx::GlobalAveragePool-1: + """ return _global_pool(prog, 'avg', @@ -1280,8 +1299,8 @@ def GlobalMaxPool(prog, *args, **kwargs): """ - onnx::GlobalMaxPool-1: - """ + onnx::GlobalMaxPool-1: + """ return _global_pool(prog, 'max', @@ -1295,8 +1314,8 @@ def GlobalMaxPool(prog, def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): """ - onnx::MaxPool-10: - """ + onnx::MaxPool-10: + """ return _pool(prog, 'max', inputs, outputs, attrs, value_infos, name=name) @@ -1304,16 +1323,16 @@ def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args, def MaxRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): """ - onnx::MaxRoiPool-1: - """ + onnx::MaxRoiPool-1: + """ _roi_pool(prog, 'roi_pool', inputs, outputs, attrs, value_infos, name) def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): """ - onnx::Pad-2: - """ + onnx::Pad-2: + """ # I/O val_data, = inputs @@ -1330,9 +1349,9 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): assume_pad2d = False if len(pads) == 4: assume_pad2d |= mode != 'constant' - if data_shape: + if data_shape is not None: assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW - if output_shape: + if output_shape is not None: assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW od_attrs = {'pad_value': value} if assume_pad2d: @@ -1383,8 +1402,8 @@ def PRelu(prog, *args, **kwargs): """ - onnx::PRelu-9: - """ + onnx::PRelu-9: + """ # I/O val_x, val_slope = inputs @@ -1404,10 +1423,12 @@ def PRelu(prog, mode = 'element' fluid_op = 'prelu' name_attr = ', name={}'.format(repr(name)) if name else '' + if embed_params: + embed_params = _check_embeddable(value_infos, val_slope) if embed_params: assert name != '' var_slope = name + '.w_0' - value_infos[val_slope].setdefault('embeded_as', []).append(var_slope) + value_infos[val_slope]['embeded_as'].append(var_slope) param_attr = '' else: var_slope = _make_var_name(val_slope) @@ -1436,16 +1457,16 @@ def PRelu(prog, def PsRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): """ - caffe2::PsRoiPool - """ + caffe2::PsRoiPool + """ _roi_pool(prog, 'psroi_pool', inputs, outputs, attrs, value_infos, name) def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): """ - onnx::Reshape-5: - """ + onnx::Reshape-5: + """ # I/O val_data, val_shape = inputs @@ -1474,6 +1495,8 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): name_attr = ', name={}'.format(repr(name)) if name else '' # generation + val_shape_int32 = val_shape + '_int32' # explicit variable + var_shape_int32 = _make_var_name(val_shape_int32) prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape)) if is_const_shape: prog.Code('{} = layers.{}({}' @@ -1487,8 +1510,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): name_attr, )) else: - val_shape_int32 = val_shape + '_int32' # explicit variable - var_shape_int32 = _make_var_name(val_shape_int32) prog.Op( '', 'Cast', @@ -1514,34 +1535,26 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): var_xshape = name + '.xshape' # dummy output prog.VarDesc(var_reshaped) prog.VarDesc(var_xshape) - if is_const_shape: - prog.OpDesc( - fluid_op, - ([var_data], 'X'), - ([var_reshaped, var_xshape], 'Out', 'XShape'), - {'shape': shape}, - ) - else: - prog.OpDesc( - fluid_op, - ([var_data, var_shape_int32], 'X', 'Shape'), - ([var_reshaped, var_xshape], 'Out', 'XShape'), - {'shape': shape}, - ) + prog.OpDesc( + fluid_op, + ([var_data, var_shape_int32], 'X', 'Shape'), + ([var_reshaped, var_xshape], 'Out', 'XShape'), + {'shape': shape}, + ) def Resize(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): """ - onnx::Resize-10: - """ + onnx::Resize-10: + """ return _interpolate(prog, inputs, outputs, attrs, value_infos, name=name) def RoiAlign(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): """ - caffe2::RoiAlign - """ + caffe2::RoiAlign + """ _roi_pool(prog, 'roi_align', inputs, outputs, attrs, value_infos, name) @@ -1580,8 +1593,8 @@ def RoiAlign(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): """ - onnx::Slice-1:9 - """ + onnx::Slice-1:9 + """ # I/O val_data, = inputs @@ -1595,7 +1608,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): starts = attrs['starts'] # required ends = attrs['ends'] # required shape = _shape_or_none(value_infos, val_data) - if shape: + if shape is not None: # ndims = len(shape) # for idx, value in enumerate(axes): # if value > ONNX_INT_MAX // 2: @@ -1639,8 +1652,8 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs): """ - onnx::Split-2: - """ + onnx::Split-2: + """ # I/O val_input, = inputs @@ -1680,8 +1693,8 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs): def Sum(prog, inputs, outputs, *args, **kwargs): """ - onnx::Sum-8: - """ + onnx::Sum-8: + """ # I/O val_sum, = outputs @@ -1710,8 +1723,8 @@ def Sum(prog, inputs, outputs, *args, **kwargs): def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): """ - onnx::Tile-1: - """ + onnx::Tile-1: + """ # I/O val_input, val_repeats = inputs @@ -1749,8 +1762,8 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs): """ - onnx::Transpose-1: - """ + onnx::Transpose-1: + """ # I/O val_data, = inputs @@ -1795,8 +1808,8 @@ def Upsample(prog, *args, **kwargs): """ - onnx::Upsample-9:9 - """ + onnx::Upsample-9:9 + """ return _interpolate(prog, inputs, outputs, attrs, value_infos, name=name) diff --git a/onnx2fluid/onnx2fluid/torch_export_helper.py b/onnx2fluid/onnx2fluid/torch_export_helper.py index 7a667b65d8bf3e22557a25b2e3162658d000801c..7a0fd6031433e989fafc142a40e9bad46df9f41f 100644 --- a/onnx2fluid/onnx2fluid/torch_export_helper.py +++ b/onnx2fluid/onnx2fluid/torch_export_helper.py @@ -25,7 +25,8 @@ def ensure_tuple(obj): def flatten_list(obj, out=None): - assert isinstance(obj, list) + assert isinstance(obj, list), 'list type required' + if out is None: out = type(obj)() for item in obj: @@ -38,11 +39,11 @@ def flatten_list(obj, out=None): def export_data(state_dict, prefix=''): """ - export binary data with meta text for raw C++ inference engines - """ + export binary data with meta text for raw C++ inference engines + """ def str_(obj): - if isinstance(obj, (tuple, list)): + if isinstance(obj, (tuple, list, set)): return str(obj)[1:-1].replace(' ', '') return str(obj) @@ -72,8 +73,8 @@ def export_onnx_with_validation(model, *args, **kwargs): """ - export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file - """ + export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file + """ is_tuple_or_list = lambda x: isinstance(x, (tuple, list)) diff --git a/onnx2fluid/onnx2fluid/validation.py b/onnx2fluid/onnx2fluid/validation.py index dc679df05093b0e9f21c2ea15eca07e308033892..223e1116fdd652c525a6f713e9d813eef7478c1c 100644 --- a/onnx2fluid/onnx2fluid/validation.py +++ b/onnx2fluid/onnx2fluid/validation.py @@ -10,14 +10,15 @@ import importlib, logging, os, sys def flatten_dict(obj, out=None): - assert isinstance(obj, dict) + assert isinstance(obj, dict), 'dict type required' + if out is None: out = type(obj)() for key, value in obj.items(): if isinstance(value, dict): flatten_dict(value, out) else: - assert key not in out + assert key not in out, 'key conflicted' out[key] = value return out @@ -29,15 +30,16 @@ def ensure_list(obj): def validate(fluid_model_filename, - golden_data_filename, - model_func_name='inference', + golden_data_filename='', atol=1e-3, rtol=1e-3, + model_func_name='inference', save_inference_model=False, + inference_input_names=None, **kwargs): """ - inference the converted Paddle fluid model, validate with given golden data - """ + inference the converted Paddle fluid model, validate with given golden data + """ import numpy as np import paddle.fluid as fluid @@ -86,24 +88,50 @@ def validate(fluid_model_filename, raise ValueError('unsupported Paddle fluid model filename') # load data - logger.info('using golden data %s', golden_data_filename) - if golden_data_filename.endswith('.npz'): - test_data = np.load(golden_data_filename, encoding='bytes') - input_data = test_data['inputs'].tolist() - output_data = test_data['outputs'].tolist() + if golden_data_filename: + logger.info('using golden data %s', golden_data_filename) + if golden_data_filename.endswith('.npz'): + test_data = np.load(golden_data_filename, encoding='bytes') + input_data = test_data['inputs'].tolist() + output_data = test_data['outputs'].tolist() + else: + test_data = np.load(golden_data_filename, encoding='bytes').tolist() + input_data = test_data['inputs'] + output_data = test_data['outputs'] + input_data = flatten_dict(input_data) + output_data = flatten_dict(output_data) + input_names = input_data.keys() + logger.info('found %d I/O golden data, starting test ...', + len(input_data) + len(output_data)) else: - test_data = np.load(golden_data_filename, encoding='bytes').tolist() - input_data = test_data['inputs'] - output_data = test_data['outputs'] - input_data = flatten_dict(input_data) - output_data = flatten_dict(output_data) - logger.info('found %d I/O golden data, starting test ...', - len(input_data) + len(output_data)) - - # DEBUG: reload test for Python code - if basename.endswith('.py') and save_inference_model: + assert inference_input_names, 'input names required for type-shape inference' + + input_names = inference_input_names + logger.info('using input names: %s', ', '.join(input_names)) + + # type-shape inference and re-save + if save_inference_model: + for block in prog.blocks: + block_desc = block.desc + for idx_op in range(block_desc.op_size()): + op_desc = block_desc.op(idx_op) + if op_desc.type() in ('feed', 'fetch'): + continue + op_desc.infer_var_type(block_desc) + op_desc.infer_shape(block_desc) + for var_name, var in block.vars.items(): + var_desc = var.desc + if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR: + continue + # WORKAROUND: dirty way to give dtype to partial-infered vars + # which could not be cleared! + try: + var.to_string(True) + except ValueError: + var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32) + fluid.io.save_inference_model(fluid_model_dir, - input_data.keys(), + input_names, var_outs, exe, main_program=prog, @@ -112,8 +140,12 @@ def validate(fluid_model_filename, fluid.io.load_inference_model(fluid_model_dir, exe) logger.info('model re-load passed') + if not golden_data_filename: + return True + # execute - outputs = exe.run(prog, feed=input_data, fetch_list=out_names) + outputs = exe.run(prog, feed=input_data, + fetch_list=out_names) # out_names can be vars logger.info('execution passed') # validate @@ -134,11 +166,10 @@ def validate(fluid_model_filename, logger.info('accuracy passed') else: logger.info('accuracy not passed') - return passed -if __name__ == '__main__': +def main(): import argparse parser = argparse.ArgumentParser( @@ -160,6 +191,7 @@ if __name__ == '__main__': '--test_data', '-t', type=str, + default='', help='I/O golden data for validation, e.g. test.npy, test.npz', ) parser.add_argument( @@ -175,19 +207,36 @@ if __name__ == '__main__': default=1e-2, help='assertion relative tolerance for validation', ) + parser.add_argument( + '--infer_inputs', + '-i', + nargs='?', + default=None, + const='', + help= + 'perform type-shape inference with given input names and re-save model', + ) args = parser.parse_args() logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s' logging_level = logging.DEBUG if args.debug else logging.INFO logging.basicConfig(format=logging_format, level=logging_level) - debug = args.debug + # debug = args.debug fluid_model_filename = args.model[0] golden_data_filename = args.test_data atol, rtol = args.atol, args.rtol + save_inference_model = args.infer_inputs is not None + inference_input_names = args.infer_inputs.split( + ',') if args.infer_inputs else None validate(fluid_model_filename, - golden_data_filename, + golden_data_filename=golden_data_filename, atol=atol, rtol=rtol, - save_inference_model=debug) + save_inference_model=save_inference_model, + inference_input_names=inference_input_names) + + +if __name__ == '__main__': + main() diff --git a/onnx2fluid/onnx2fluid/writer.py b/onnx2fluid/onnx2fluid/writer.py index c6aba4b61feaa5d7d4f045748706c008643245b0..70bd72765e067538230c5e5644d994af8d902099 100644 --- a/onnx2fluid/onnx2fluid/writer.py +++ b/onnx2fluid/onnx2fluid/writer.py @@ -44,6 +44,8 @@ def irepr(obj, to='_'): def flatten_list(obj, out=None): + assert isinstance(obj, list), 'list type required' + if out is None: out = type(obj)() for item in obj: @@ -56,12 +58,12 @@ def flatten_list(obj, out=None): def make_attr_name(name): """ - make a valid code name for ParamAttr - """ + make a valid code name for ParamAttr + """ - if name == '': - raise ValueError('name should not be empty') - for s in ' \\|/:': # + assert name != '', 'name should not be empty' + + for s in ' \\|/:-': # name = name.replace(s, '_') if not name.startswith('_'): name = '_' + name @@ -70,8 +72,8 @@ def make_attr_name(name): class Program(object): """ - fluid Python code and ProgramDesc wrapper - """ + fluid Python code and ProgramDesc wrapper + """ DTYPE_TO_FRAMEWORK_DTYPE = { 'bool': framework_pb2.VarType.BOOL, @@ -88,8 +90,8 @@ class Program(object): @staticmethod def Dtype(dtype): """ - convert dtype to fulid framework dtype - """ + convert dtype to fulid framework dtype + """ dtype = np.dtype(dtype).name return Program.DTYPE_TO_FRAMEWORK_DTYPE[dtype] @@ -97,8 +99,8 @@ class Program(object): @staticmethod def OpDescVars(vals, *keys): """ - make (OpDesc.Var)s - """ + make (OpDesc.Var)s + """ od_vars = [] for idx, key in enumerate(keys): @@ -112,8 +114,8 @@ class Program(object): @staticmethod def OpDescAttrs(attrs): """ - make (OpDesc.Attr)s - """ + make (OpDesc.Attr)s + """ od_attrs = [] for key, value in attrs.items(): @@ -178,8 +180,8 @@ class Program(object): def Code(self, code): """ - add Python code - """ + add Python code + """ if self.code_mutable: self.codes.append(code) @@ -190,16 +192,16 @@ class Program(object): output_val_keys=None, attrs=None): """ - add OpDesc - """ + add OpDesc + """ desc = framework_pb2.OpDesc() desc.type = op_type - if input_val_keys is not None: + if input_val_keys: desc.inputs.extend(self.OpDescVars(*input_val_keys)) - if output_val_keys is not None: + if output_val_keys: desc.outputs.extend(self.OpDescVars(*output_val_keys)) - if attrs is not None: + if attrs: desc.attrs.extend(self.OpDescAttrs(attrs)) self.op_descs.append(desc) return desc @@ -210,8 +212,8 @@ class Program(object): value_info=None, remove_batch=None): """ - add VarDesc, - """ + add VarDesc, + """ assert var_name not in self.var_descs, 'var naming conflicted' @@ -220,13 +222,16 @@ class Program(object): var_desc.persistable = persistable var_desc.type.type = framework_pb2.VarType.LOD_TENSOR self.var_descs[var_name] = var_desc + if value_info: - self.VarTypeInfo(var_name, value_info, remove_batch=remove_batch) + self.VarTypeShapeInfo(var_name, + value_info, + remove_batch=remove_batch) def Op(self, domain, op_type, *args, **kwargs): """ - convert an ONNX op and add it to program - """ + convert an ONNX op and add it to program + """ if domain != '': # TODO: symbolic file routing by domain raise ValueError('only default domain supported') @@ -242,8 +247,8 @@ class Program(object): def IntermediateOp(self, domain, op_type, *args, **kwargs): """ - convert an intermediate ONNX op declaring in desc program only - """ + convert an intermediate ONNX op declaring in desc program only + """ code_mutable = self.code_mutable self.code_mutable = False @@ -255,10 +260,10 @@ class Program(object): else: self.code_mutable = code_mutable - def VarTypeInfo(self, var_name, value_info, remove_batch=None): + def VarTypeShapeInfo(self, var_name, value_info, remove_batch=None): + """ + set value_info for var """ - set value_info for var - """ if var_name not in self.var_descs: return @@ -284,8 +289,8 @@ class Program(object): class Writer(object): """ - fluid code and desc writter - """ + fluid code and desc writter + """ # CODE_INDENT = ' ' * 4 CODE_INDENT = '\t' @@ -293,8 +298,8 @@ class Writer(object): @staticmethod def header_code(func_name, info=''): """ - Python header codes - """ + Python header codes + """ codes = [] codes.append('"""') @@ -315,8 +320,8 @@ class Writer(object): def emit_op(prog, name, domain, op_type, inputs, outputs, attrs, value_infos, *args, **kwargs): """ - emit an ONNX op into program - """ + emit an ONNX op into program + """ prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type, inputs, outputs, @@ -334,8 +339,8 @@ class Writer(object): @staticmethod def emit_param(prog, name, value_info): """ - emit an ONNX weight into program - """ + emit an ONNX weight into program + """ if value_info.get('embeded_as', []): var_names = value_info['embeded_as'] @@ -359,8 +364,8 @@ class Writer(object): @staticmethod def emit_inputs(prog, names, value_infos, remove_batch=None): """ - emit ONNX inputs into program - """ + emit ONNX inputs into program + """ for idx, name in enumerate(names): var_name = make_var_name(name) @@ -396,8 +401,8 @@ class Writer(object): @staticmethod def emit_outputs(prog, names): #, value_infos """ - emit ONNX outputs into program - """ + emit ONNX outputs into program + """ code = 'return ' for idx, name in enumerate(names): @@ -416,8 +421,8 @@ class Writer(object): @staticmethod def add_codes(codes, others, indent): """ - flatten codes in program - """ + flatten codes in program + """ for code in flatten_list(others): codes.append(Writer.CODE_INDENT * indent + code) @@ -426,11 +431,10 @@ class Writer(object): @staticmethod def write_weight(weight, filename): """ - write single weight in fluid desc - """ + write single weight in fluid desc + """ - if not isinstance(weight, np.ndarray): - raise TypeError('weight is not an ndarray') + assert isinstance(weight, np.ndarray), 'weight is not an ndarray' tensor_desc = framework_pb2.VarType.TensorDesc() tensor_desc.data_type = Program.Dtype(weight.dtype) @@ -448,12 +452,11 @@ class Writer(object): @staticmethod def write_weights(weights, save_dir): """ - write multiple weights in each fluid desc - """ + write multiple weights in each fluid desc + """ for name, weight in weights.items(): - if not isinstance(weights, dict): - raise TypeError('dict type weights required') + assert isinstance(weights, dict), 'dict type weights required' var_name = make_var_name(name) filename = os.path.join(save_dir, var_name) @@ -463,8 +466,8 @@ class Writer(object): @staticmethod def write_code_file(filename, header_code, *body_codes): """ - write Python code to file - """ + write Python code to file + """ codes = [] Writer.add_codes(codes, header_code, 0) @@ -481,8 +484,8 @@ class Writer(object): @staticmethod def write_desc_file(filename, op_descs, var_descs): """ - write desc program to file - """ + write desc program to file + """ prog_desc = framework_pb2.ProgramDesc() block_desc = prog_desc.blocks.add() diff --git a/onnx2fluid/setup.cfg b/onnx2fluid/setup.cfg index 96f0f9b2a28ad6cc5302debc7aac4dd8c79928ea..bf59c1fe961108f01413d397a1d90a9bf0bd986b 100644 --- a/onnx2fluid/setup.cfg +++ b/onnx2fluid/setup.cfg @@ -54,8 +54,8 @@ zip_safe = True [options.entry_points] console_scripts = onnx2fluid = onnx2fluid.__main__ - onnx2fluid_convert = onnx2fluid.conversion - onnx2fluid_validate = onnx2fluid.validation + onnx2fluid_convert = onnx2fluid.conversion:main + onnx2fluid_validate = onnx2fluid.validation:main # 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下 # 仅支持文件,不支持目录,但可以使用通配