diff --git a/README.md b/README.md index d7192697de6eacc5e22ac704adf55fe38f48f60f..c3ee6e3e766e6e92ad7978d305fce77f88b939ac 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ paddlepaddle >= 1.5.0 **以下依赖只需对应安装自己需要的即可** 转换tensorflow模型 : tensorflow == 1.14.0 转换caffe模型 : caffe == 1.0.0 - +转换onnx模型 : onnx == 1.5.0 pytorch == 1.1.0 ## 安装 ``` pip install x2paddle @@ -32,8 +32,9 @@ x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model x2paddle --framework=caffe --prototxt=deploy.proto --weight=deploy.caffemodel --save_dir=pd_model ``` ### ONNX -即将release,目前仍可使用[onnx2fluid](https://github.com/PaddlePaddle/X2Paddle/tree/release-0.3/onnx2fluid) - +``` +x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model +``` ### 参数选项 | 参数 | | |----------|--------------| @@ -42,7 +43,7 @@ x2paddle --framework=caffe --prototxt=deploy.proto --weight=deploy.caffemodel -- |--weight | 当framework为caffe时,该参数指定caffe模型的参数文件路径 | |--save_dir | 指定转换后的模型保存目录路径 | |--model | 当framework为tensorflow时,该参数指定tensorflow的pb模型文件路径 | -|--caffe_proto | [可选]由caffe.proto编译成caffe_pb2.py文件的存放路径,当没有安装caffe或者使用自定义Layer时使用,默认为None | +|--caffe_proto | [可选]由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None | ## 使用转换后的模型 转换后的模型包括`model_with_code`和`inference_model`两个目录。 diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 8c74504041febc6d2abfd67f0dc8f131f0087b74..2edbe391c803914253e3356e673d8c60f12ad59f 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -15,7 +15,6 @@ from six import text_type as _text_type import argparse import sys -import x2paddle def arg_parser(): @@ -104,9 +103,32 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto): mapper.save_inference_model(save_dir) +def onnx2paddle(model_path, save_dir): + # check onnx installation and version + try: + import onnx + version = onnx.version.version + if version != '1.5.0': + print("onnx==1.5.0 is required") + return + except: + print("onnx is not installed, use \"pip install onnx==1.5.0\".") + return + + from x2paddle.decoder.onnx_decoder import ONNXDecoder + from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper + from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer + print("Now translating model from onnx to paddle.") + model = ONNXDecoder(model_path) + mapper = ONNXOpMapper(model) + optimizer = ONNXOptimizer(mapper) + optimizer.delete_redundance_code() + mapper.save_inference_model(save_dir) + + def main(): if len(sys.argv) < 2: - print("Use \"x2paddle -h\" to print the help information\n") + print("Use \"x2paddle -h\" to print the help information") return parser = arg_parser() @@ -124,7 +146,6 @@ def main(): return except: print("paddlepaddle not installed, use \"pip install paddlepaddle\"") - assert args.framework is not None, "--from is not defined(tensorflow/caffe)" assert args.save_dir is not None, "--save_dir is not defined" @@ -136,9 +157,11 @@ def main(): assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model" caffe2paddle(args.prototxt, args.weight, args.save_dir, args.caffe_proto) - + elif args.framework == "onnx": + assert args.model is not None, "--model should be defined while translating onnx model" + onnx2paddle(args.model, args.save_dir) else: - raise Exception("--framework only support tensorflow/caffe now") + raise Exception("--framework only support tensorflow/caffe/onnx now") if __name__ == "__main__": diff --git a/x2paddle/decoder/onnx_decoder.py b/x2paddle/decoder/onnx_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5ecd3c6e7787df71e2fd557423dcc10d5558ed18 --- /dev/null +++ b/x2paddle/decoder/onnx_decoder.py @@ -0,0 +1,497 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from x2paddle.core.graph import GraphNode, Graph +from x2paddle.core.fluid_code import FluidCode +from onnx.checker import ValidationError +from onnx.checker import check_model +from onnx.utils import polish_model +from onnx.version_converter import convert_version +from onnx import helper +from onnx.helper import get_attribute_value, make_attribute +from onnx.shape_inference import infer_shapes +from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE +from onnx.numpy_helper import to_array +from collections import OrderedDict as Dict +import onnx +import numpy as np +from copy import deepcopy +import logging as _logging + +default_op_domain = 'ai.onnx' +_logger = _logging.getLogger(__name__) + + +class ONNXGraphNode(GraphNode): + def __init__(self, layer, layer_name=None): + if layer_name is None: + super(ONNXGraphNode, self).__init__(layer, layer.name) + else: + super(ONNXGraphNode, self).__init__(layer, layer_name) + self.layer_type = layer.op_type + self.fluid_code = FluidCode() + self.attr_map = self.get_attr_map() + self.dtype_map = {1: "float32", 3: "int32", 9: "int64"} + self.weight_inputs = list() + self.out_shapes = None + self.dtype = None + + def get_attr_map(self): + """ + convert ONNX node attributes to dict + """ + return { + attr.name: self.get_attribute_value2(attr) + for attr in self.layer.attribute + } + + @property + def value(self): + assert 'Constant' in self.layer_type, "Only Constant node has value." + + attr = self.layer.attr['value'] + if 'value' in self.attr_map: + return default + return self.attr_map[name] + + def get_attribute_value2(self, attr): + """ + get_attribute_value enhanced + """ + if attr.type == onnx.AttributeProto.TENSOR: + dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type]) + data = attr.t.raw_data + value = np.frombuffer(data, + dtype=dtype, + count=(len(data) // dtype.itemsize)) + elif attr.type == onnx.AttributeProto.STRING: + value = attr.s + value = value.decode() if isinstance(value, bytes) else value + else: + value = get_attribute_value(attr) + return value + + def get_attr(self, name, default=None): + """ + get_attribute_value from attr_map + """ + if name not in self.attr_map: + return default + return self.attr_map[name] + + +class ONNXGraphDataNode(GraphNode): + def __init__(self, layer, layer_name=None, is_global_input=False): + if layer_name is None: + super(ONNXGraphDataNode, self).__init__(layer, layer.name) + else: + super(ONNXGraphDataNode, self).__init__(layer, layer_name) + if is_global_input: + self.layer_type = 'place_holder' + else: + self.layer_type = 'create_parameter' + self.layer_name = layer_name + self.fluid_code = FluidCode() + self.weight = None + self.embeded_as = None + + @property + def out_shapes(self): + values = self.layer.type.tensor_type.shape.dim + out_shapes = list() + out_shapes = [dim.dim_value for dim in values] + return out_shapes + + @property + def dtype(self): + dtype = self.layer.type.tensor_type.elem_type + + return TENSOR_TYPE_TO_NP_TYPE[dtype] + + +class ONNXGraph(Graph): + def __init__(self, model): + super(ONNXGraph, self).__init__(model) + self.initializer = {} + self.place_holder_nodes = list() + self.get_place_holder_nodes() + + def get_inner_nodes(self): + """ + generate inner node of ONNX model + """ + inner_nodes = [] + if not isinstance(self.model, onnx.GraphProto): + logger.error('graph is not a GraphProto instance') + return + for initializer in self.model.initializer: + name = initializer.name + inner_nodes.append(name) + return inner_nodes + + def get_place_holder_nodes(self): + """ + generate place_holder node of ONNX model + """ + inner_nodes = self.get_inner_nodes() + input_nodes = [value.name for value in self.model.input] + for ipt_data in input_nodes: + if ipt_data not in inner_nodes: + self.place_holder_nodes.append(ipt_data) + + def is_place_holder_nodes(self, layer): + """ + return layer is or not place_holder node + """ + if layer in self.place_holder_nodes: + return True + return False + + def build(self): + """ + build topo_sort of ONNX model + """ + for layer in self.model.node: + self.node_map[layer.name] = ONNXGraphNode(layer) + + #set op node's dtype and out_shapes + for item in self.model.value_info: + if item.name in self.node_map: + self.node_map[item.name].dtype = TENSOR_TYPE_TO_NP_TYPE[ + item.type.tensor_type.elem_type] + self.node_map[item.name].out_shapes = [ + dim.dim_value for dim in item.type.tensor_type.shape.dim + ] + + for layer in self.model.input: + if layer.name not in self.node_map: + is_place_holder = self.is_place_holder_nodes(layer.name) + self.node_map[layer.name] = ONNXGraphDataNode( + layer, + layer_name=layer.name, + is_global_input=is_place_holder) + #set data node's weight + for name, weight in self.graph_weights(self.model): + if name in self.node_map: + if isinstance(self.node_map[name], ONNXGraphDataNode): + self.node_map[name].weight = weight + self.node_map[name].embeded_as = [] + + #generate connection between nodes for topo + for layer_name, node in self.node_map.items(): + if isinstance(node, ONNXGraphNode): + for idx, in_node in enumerate(node.layer.input): + if in_node not in self.node_map: + raise Exception( + 'input[{}] of node[{}] does not exist in node_map'. + format(in_node, layer_name)) + else: + self.connect(in_node, layer_name) + + #generate topo + super(ONNXGraph, self).build() + + self.input_nodes = self.place_holder_nodes + + def get_nodes(self, names, copy=False): + """ + get nodes by more than one name + """ + nodes = [] + for name in names: + nodes.add(self.get_node(name, copy=copy)) + + def graph_weights(self, graph): + """ + generator for weights + """ + + if not isinstance(graph, onnx.GraphProto): + logger.error('graph is not a GraphProto instance') + return + + for initializer in graph.initializer: + name = initializer.name + weight = to_array(initializer) + yield name, weight + + +class ONNXDecoder(object): + def __init__(self, onnx_model): + model = onnx.load(onnx_model) + print('model ir_version: {}, op version: {}'.format( + model.ir_version, model.opset_import[0].version)) + + if model.opset_import[0].version < 9: + _logger.warning( + 'Now, onnx2paddle main support convert onnx model opset_verison == 9,' + 'opset_verison of your onnx model is %d < 9,' + 'some operator may cannot convert.', + model.opset_import[0].version) + check_model(model) + + model = polish_model(model) + + model = self.optimize_model_skip_op_for_inference(model) + model = self.optimize_model_strip_initializer(model) + self.standardize_variable_name(model.graph) + + self.model = model + graph_def = model.graph + + self.onnx_graph = ONNXGraph(graph_def) + self.onnx_graph.build() + + def build_value_refs(self, nodes): + """ + build op reference of inputs and outputs + """ + input_refs = Dict() + output_refs = Dict() + for idx, node in enumerate(nodes): + for val_name in node.input: + input_refs.setdefault(val_name, set()).add(idx) + for val_name in node.output: + output_refs.setdefault(val_name, set()).add(idx) + return input_refs, output_refs + + def skip_node_forward(self, nodes, src_output_name, dst_input_name, + input_refs): + """ + skip nodes between src_output_name -> dst_input_name and connect this pair + """ + processed = 0 + for next_idx in input_refs[src_output_name]: + next_node = nodes[next_idx] + for val_idx, next_input_name in enumerate(next_node.input): + if next_input_name == src_output_name: + next_node.input[val_idx] = dst_input_name + processed += 1 + return processed + + def skip_node_backward(self, nodes, src_input_name, dst_output_name, + output_refs): + """ + skip nodes between dst_output_name -> src_input_name and connect this pair + """ + processed = 0 + for prev_idx in output_refs[src_input_name]: + prev_node = nodes[prev_idx] + for val_idx, prev_output_name in enumerate(prev_node.output): + if prev_output_name == src_input_name: + prev_node.output[val_idx] = dst_output_name + processed += 1 + return processed + + def optimize_model_skip_op_for_inference(self, model, op_list=None): + """ + skip ops can be bypassed for inference + """ + if op_list is None: + op_list = ['Dropout'] + + nodes = model.graph.node + input_refs, output_refs = self.build_value_refs(nodes) + ret = type(model)() + ret.CopyFrom(model) + ret_nodes = ret.graph.node + nodes_to_remove = [] + for node_idx, node in enumerate(nodes): + if not (node.domain == default_op_domain or node.domain == ''): + continue + op_type = node.op_type + if not (op_type in op_list): + continue + if op_type in ['Dropout']: + input_name = node.input[0] + output_name = node.output[0] + elif not (len(node.input) == 1 and len(node.output) == 1): + print( + 'currently only 1-input-1-output op supported, skip required %d: %s', + node_idx, node.op_type) + continue + else: + input_name = node.input[0] + output_name = node.output[0] + + if output_name in input_refs: + processed = self.skip_node_forward(ret_nodes, output_name, + input_name, input_refs) + elif input_name in output_refs: + processed = self.skip_node_backward(ret_nodes, input_name, + output_name, output_refs) + else: + processed = -1 + + if processed > 0: + nodes_to_remove.append(node_idx) + print('skip op {}: {} -> {} -> {}'.format( + node_idx, input_name, node.op_type, output_name)) + elif processed == 0: + print('weird, no node processed') + else: + print('standalone op {}: {} -> {} -> {} not skipped'.format( + node_idx, input_name, node.op_type, output_name)) + + nodes_to_remove.sort(reverse=True) + for node_idx in nodes_to_remove: + ret_nodes.pop(node_idx) + return ret + + def optimize_model_strip_initializer(self, model, keep_input_only=True): + """ + strip weights for inference + """ + nodes = model.graph.node + input_refs, output_refs = self.build_value_refs(nodes) + out_names = [val.name for val in model.graph.output] + + ret = type(model)() + ret.CopyFrom(model) + # strip initializers + ret.graph.ClearField('initializer') + ret_initializers = ret.graph.initializer + for initializer in model.graph.initializer: + name = initializer.name + if name in input_refs: + ret_initializers.add().CopyFrom(initializer) + elif not keep_input_only and name in output_refs: + ret_initializers.add().CopyFrom(initializer) + else: + dtype = TENSOR_TYPE_TO_NP_TYPE[initializer.data_type] + + # strip inputs + ret.graph.ClearField('input') + ret_inputs = ret.graph.input + for item in model.graph.input: + name = item.name + if name in input_refs or name in out_names: + ret_inputs.add().CopyFrom(item) + return ret + + def make_variable_name(self, name): + """ + make a valid code name for ParamAttr + """ + + if name == '': + raise ValueError('name should not be empty') + for s in ' .*?\\/-:': # + name = name.replace(s, '_') + return '_' + name + + def standardize_variable_name(self, graph): + """ + standardize variable name for paddle's code + """ + + for initializer in graph.initializer: + initializer.name = self.make_variable_name(initializer.name) + for ipt in graph.input: + ipt.name = self.make_variable_name(ipt.name) + for output in graph.output: + output.name = self.make_variable_name(output.name) + for item in graph.value_info: + item.name = self.make_variable_name(item.name) + for node in graph.node: + if node.name == '': + node.name = node.output[0] + node.name = self.make_variable_name(node.name) + for i in range(len(node.input)): + node.input[i] = self.make_variable_name(node.input[i]) + for i in range(len(node.output)): + node.output[i] = self.make_variable_name(node.output[i]) + + def split_model(self, model, outputs=None): + """ + Takes a model and changes its outputs. + """ + if outputs is None: + raise RuntimeError("outputs is None") + if outputs == model.graph.output[0].name: + return model + nodes = model.graph.node + keep_nodes = [] + + # all the nodes we need to keep. + for node in nodes: + if outputs in node.output: + keep_nodes.append(node) + break + keep_nodes.append(node) + + infer_shapes = onnx.shape_inference.infer_shapes(model) + + var_out = [] + for value_info in infer_shapes.graph.value_info: + if value_info.name == outputs: + var_out.append(value_info) + break + + graph = helper.make_graph(keep_nodes, model.graph.name, + model.graph.input, var_out, + model.graph.initializer) + + onnx_model = helper.make_model(graph) + onnx_model.ir_version = model.ir_version + onnx_model.producer_name = model.producer_name + onnx_model.producer_version = model.producer_version + onnx_model.domain = model.domain + onnx_model.model_version = model.model_version + onnx_model.doc_string = model.doc_string + + if len(onnx_model.graph.input) != len(model.graph.input): + raise RuntimeError("Input mismatch {} != {}".format( + len(onnx_model.input), len(model.input))) + return onnx_model + + def get_dynamic_shape_from_caffe2(self, layer, input_shapes): + """ + get dynamic shape from caffe2.backend + """ + try: + import torch + version = torch.__version__ + if '1.1.0' not in version: + print("your model have dynamic graph, torch==1.1.0 is required") + return + except: + print( + "your model have dynamic graph, we use caff2 to inference graph, please use \"pip install torch==1.1.0\"." + ) + return + from caffe2.python.onnx.backend import prepare + shape = input_shapes[0] + np_images = np.random.rand(shape[0], shape[1], shape[2], + shape[3]).astype('float32') + num_onnx = self.split_model(self.model, layer) + prepared_backend = prepare(num_onnx, device='CPU') + output = prepared_backend.run(inputs=np_images) + return output[0].tolist() + + def get_dynamic_shape_from_onnx(self, layer, input_shapes): + """ + get dynamic shape from onnxruntime + """ + import onnxruntime as rt + from onnxruntime.backend import prepare + import numpy as np + num_onnx = self.split_model(self.model, layer) + sess = prepare(num_onnx) + shape = input_shapes[0] + print(shape) + np_images = np.random.rand(shape[0], shape[1], shape[2], + shape[3]).astype('float32') + output = sess.run(model=sess, inputs=np_images) + return output[0].tolist() diff --git a/x2paddle/op_mapper/onnx_directly_map.py b/x2paddle/op_mapper/onnx_directly_map.py new file mode 100644 index 0000000000000000000000000000000000000000..75c2a842f8074761b6a1a4cbacd3248fb8254526 --- /dev/null +++ b/x2paddle/op_mapper/onnx_directly_map.py @@ -0,0 +1,54 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict as _dict +import numpy as _np + +default_op_mapping_field_values = _dict() +default_op_mapping_field_values['FLUID_OP'] = '' +default_op_mapping_field_values['FLUID_INPUT_ARGS'] = None +default_op_mapping_field_values['FLUID_OUTPUT_ARGS'] = None +default_op_mapping_field_values['ATTR_MAPPING'] = dict() +default_op_mapping_field_values['DEFAULTS'] = dict() +default_op_mapping_field_values['INPUT_PERM'] = None +default_op_mapping_field_values['OUTPUT_PERM'] = None +default_op_mapping_field_values['FILL_NAME_FIELD'] = True +default_op_mapping = { + 'Gather': ['gather', ['X'], ['Out'], + dict(axis='')], + 'Shape': ['shape', ['X'], ['Out']], + 'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], + dict(), + dict(axis=-1)], + 'Clip': [ + 'clip', ['X'], ['Out'], + dict(), + dict( + min=(_np.asarray([255, 255, 127, 255], + dtype=_np.uint8).view(_np.float32)), + max=(_np.asarray([255, 255, 127, 127], + dtype=_np.uint8).view(_np.float32)), + ) + ], + 'ReduceMean': [ + 'reduce_mean', ['X'], ['Out'], + dict(axes='dim', keepdims='keep_dim'), + dict(keep_dim=1) + ] +} + +default_ioa_constraint = { + 'Gather': + [(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported')], +} diff --git a/x2paddle/op_mapper/onnx_op_mapper.py b/x2paddle/op_mapper/onnx_op_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e15ec2191b28d9fce7b3e65965ae16a3561de73d --- /dev/null +++ b/x2paddle/op_mapper/onnx_op_mapper.py @@ -0,0 +1,729 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from x2paddle.core.graph import GraphNode +from x2paddle.core.op_mapper import OpMapper +from x2paddle.core.util import * +from x2paddle.core.fluid_code import Layer +from x2paddle.core.fluid_code import FluidCode +from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode +from x2paddle.op_mapper.onnx_directly_map import default_op_mapping_field_values +from x2paddle.op_mapper.onnx_directly_map import default_op_mapping +from x2paddle.op_mapper.onnx_directly_map import default_ioa_constraint +import numpy as np +import logging as _logging +from collections import OrderedDict as _dict + +_logger = _logging.getLogger(__name__) + + +def _const_weight_or_none(node): + if 'Constant' in node.layer_name: + return val.value + if isinstance(node, ONNXGraphDataNode): + return node.weight + return None + + +class ONNXOpMapper(OpMapper): + def __init__(self, decoder): + super(ONNXOpMapper, self).__init__() + self.decoder = decoder + self.graph = decoder.onnx_graph + self.input_shapes = [] + self.weights = dict() + self.omit_nodes = list() + + if not self.op_checker(): + raise Exception("Model are not supported yet.") + + #mapping op + + print("Total nodes: {}".format( + sum([ + isinstance(node, ONNXGraphNode) + for name, node in self.graph.node_map.items() + ]))) + for node_name in self.graph.topo_sort: + node = self.graph.get_node(node_name) + op = node.layer_type + if hasattr(self, op): + func = getattr(self, op) + func(node) + elif op in default_op_mapping: + self.directly_map(node) + + def op_checker(self): + unsupported_ops = set() + for node_name in self.graph.topo_sort: + node = self.graph.get_node(node_name) + op = node.layer_type + if not hasattr(self, op) and op not in default_op_mapping: + unsupported_ops.add(op) + if len(unsupported_ops) == 0: + return True + else: + print("There are {} ops not supported yet, list as below".format( + len(unsupported_ops))) + for op in unsupported_ops: + print(op) + return False + + def directly_map(self, node, *args, name='', **kwargs): + inputs = node.layer.input + outputs = node.layer.output + op_type = node.layer_type + attrs = node.attr_map + + info = default_op_mapping[op_type] + info.extend(list(default_op_mapping_field_values.values())[len(info):]) + ( + fluid_op, + fluid_input_args, + fluid_output_args, + attr_mapping, + default_attrs, + input_perm, + output_perm, + fill_name_field, + ) = info + + if fluid_op in default_ioa_constraint: + for predicate, message in default_ioa_constraint[fluid_op]: + assert predicate(inputs, outputs, attrs), message + + mapped_attrs = { + attr_mapping.get(key, key): value + for key, value in attrs.items() + } + if '' in mapped_attrs: + mapped_attrs.pop('') + if '_' in mapped_attrs: + mapped_attrs.pop('_') + fluid_attrs = default_attrs.copy() + fluid_attrs.update(mapped_attrs) + val_inps = inputs if input_perm is None else list( + map(lambda i: inputs[i], input_perm)) + val_outs = outputs if output_perm is None else list( + map(lambda i: outputs[i], output_perm)) + attr = fluid_attrs + if fluid_op not in ['shape', 'gather']: + attr['name'] = string(node.layer_name) + node.fluid_code.add_layer(fluid_op, + inputs=', '.join(val_inps), + output=val_outs[0], + param_attr=attr) + + def place_holder(self, node): + self.input_shapes.append(node.out_shapes) + attr = { + "dtype": string(node.dtype), + "shape": node.out_shapes, + "name": string(node.layer_name), + "append_batch_size": 'False' + } + + node.fluid_code.add_layer("data", + inputs=None, + output=node, + param_attr=attr) + + def create_parameter(self, node, parameter=None): + if parameter is not None: + node = parameter + dtype = node.dtype + shape = node.out_shapes + + self.weights[node.layer_name] = node.weight + attr = { + 'dtype': string(dtype), + 'shape': shape, + 'name': string(node.layer_name), + 'attr': string(node.layer_name), + 'default_initializer': 'Constant(0.0)' + } + node.fluid_code.add_layer("create_parameter", + inputs=None, + output=node, + param_attr=attr) + + def _pad_if_asymmetric(self, node, pads, val_name): # pads: SSEE + assert len(pads) & 1 == 0 + symmetric = True + ndims = len(pads) // 2 + for idx_dim in range(ndims): + if pads[idx_dim] != pads[ndims + idx_dim]: + symmetric = False + break + if symmetric: + return pads[:ndims], val_name + val_padded = self.Pad(node, op_independent=False) + return [0] * ndims, val_padded + + def Pad(self, node, op_independent=True): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + pads = node.get_attr('pads') + mode = node.get_attr('mode', 'constant') + value = node.get_attr('value', 0.) + data_shape = val_x.out_shapes + output_shape = node.out_shapes + assume_pad2d = False + attr = {} + if len(pads) == 4: + assume_pad2d |= mode != 'constant' + if data_shape: + assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW + if output_shape: + assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW + if assume_pad2d: + fluid_op = 'pad2d' + attr['data_format'] = string('NCHW') + attr['mode'] = string(mode) + else: + attr = {'pad_value': value} + assert mode == 'constant', 'mode {} is supported only in pad2d'.format( + mode) + fluid_op = 'pad' + if len(pads) == 4: + paddings = np.array(pads).reshape( + (-1, 2)).transpose().flatten().tolist() # SSEE -> SESE + elif len(pads) == 8: + paddings = np.array(pads).reshape( + (-1, 4)).transpose().flatten().tolist() # SSEE -> SESE + attr['paddings'] = paddings + if op_independent: + attr['name'] = string(node.layer_name) + node.fluid_code.add_layer(fluid_op, + inputs=val_x, + output=node, + param_attr=attr) + else: + attr['name'] = string(node.layer_name + '_paded') + node.fluid_code.add_layer(fluid_op, + inputs=val_x, + output=node.layer_name + '_paded', + param_attr=attr) + return node.layer_name + '_paded' + + def Unsqueeze(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + axes = node.get_attr('axes') + attr = {'axes': axes, 'name': string(node.layer_name)} + node.fluid_code.add_layer('unsqueeze', + inputs=val_x, + output=node, + param_attr=attr) + + def Constant(self, node): + val_output = self.graph.get_node(node.layer.output[0], copy=True) + + value = node.get_attr('value') + dtype = np.dtype(value.dtype) + output_dtype = val_output.dtype + if output_dtype: + assert dtype == output_dtype, 'tensor dtype unmatches storage dtype' + + shape = node.get_attr('shape', None) + if shape is None: + shape = val_output.out_shapes + if shape is None: + shape = list(value.shape) + _logger.warning( + 'in (Constant -> %s): ' + 'attribute "shape" of %s not inferred, ' + 'using value as 1-D tensor may lead to fails', + val_output.layer_name, val_output.layer_name) + + value = value.tolist() + if len(value) == 1: # scalar + shape = [1] + value = value[0] + if dtype.name == 'int64': + dtype = 'int32' + attr = {'shape': shape, 'dtype': string(dtype), 'value': value} + node.fluid_code.add_layer('fill_constant', + inputs=None, + output=node, + param_attr=attr) + + def Resize(self, node): + # I/O + val_x = self.graph.get_node(node.layer.input[0], copy=True) + val_scales = self.graph.get_node(node.layer.input[1], copy=True) + val_y, = self.graph.get_node(node.layer.output[0], copy=True) + + out_shape_ = val_y.out_shapes + if out_shape_ is not None: + assert len(out_shape_) == 4, 'only 4-D Tensor as X and Y supported' + out_shape_ = out_shape_[2:] + scales = _const_weight_or_none(val_scales) + if scales is not None: + assert len(scales) == 4, 'only 4-D Tensor as X and Y supported' + assert scales[0] == 1 and scales[ + 1] == 1, 'only scale on (NC)HW supported' + assert scales[2] == scales[ + 3], 'only aspect-ratio-invariant scale supported' + scale = scales[2] if scales else None + if scale is None: + assert out_shape_, 'neither scales nor output shape is available' + out_shape = out_shape_ + else: + out_shape = None + if out_shape_ is None: + in_shape = val_x.out_shapes + assert in_shape is not None, 'out_shape required but not inferrable' + assert len( + in_shape) == 4, 'only 4-D Tensor as X and Y supported' + out_shape_ = [in_shape[2] * scale, in_shape[3] * scale] + + mode = node.get_attr('mode', 'nearest') + fluid_op = 'resize_{}'.format(mode) + name_attr = ', name={}'.format(repr(name)) if name else '' + + attr = { + 'scale': scale, + 'out_shape': out_shape, + 'name': string(node.layer_name) + } + node.fluid_code.add_layer(fluid_op, + inputs=val_x, + output=node, + param_attr=attr) + + def ConstantOfShape(self, node): + val_shape = self.graph.get_node(node.layer.input[0], copy=True) + + shape = _const_weight_or_none(val_shape) + + if shape is None: + shape = node.out_shapes + + assert shape is not None, ( + 'given shape is neither const value nor deductible from output, ' + 'this is not supported') + + value = node.get_attr('value') + dtype = value.dtype + value = value.tolist() + if len(value) == 1: + shape = [1] + value = value[0] + if dtype.name == 'int64': + dtype = 'int32' + attr = {'shape': shape, 'dtype': string(dtype), 'value': value} + node.fluid_code.add_layer('fill_constant', + inputs=None, + output=node, + param_attr=attr) + + def Split(self, node): + val_input = self.graph.get_node(node.layer.input[0], copy=True) + var_outs = [val for val in node.layer.input] + + fluid_op = 'split' + split = node.get_attr['split'] + axis = node.get_attr('axis', 0) + attr = {'split': split, 'axis': axis, 'name': string(node.layer_name)} + # generation + node.fluid_code.add_layer('split', + inputs=val_input, + output=var_outs, + param_attr=attr) + + def Reshape(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + val_shape = self.graph.get_node(node.layer.input[1], copy=True) + val_reshaped = self.graph.get_node(node.layer.output[0], copy=True) + shape = None + if isinstance(val_shape, ONNXGraphDataNode): + self.omit_nodes.append(val_shape.layer_name) + + # catch dynamic graph shape + if isinstance(val_shape, ONNXGraphNode): + shape = self.decoder.get_dynamic_shape_from_caffe2( + val_shape.layer_name, self.input_shapes) + if shape is None: + shape = val_reshaped.out_shapes + + shape_dtype = val_shape.dtype + + if shape_dtype is None: + _logger.warning( + 'in op %s(%s -> Reshape -> %s): ' + 'dtype of input "shape" not inferred, int32 assumed', + node.layer_name, val_x.layer_name, val_reshaped.layer_name) + shape_dtype = _np.dtype('int32') + if shape is None: + shape = [1, -1] + _logger.warning( + 'in %s(%s -> Reshape -> %s): ' + 'input "shape" not inferred, use [1, -1] as dummy value, ' + 'the behavior of Paddle fluid maybe undefined', node.layer_name, + val_x.layer_name, val_reshaped.layer_name) + attr = {'shape': shape, 'name': string(node.layer_name)} + + node.fluid_code.add_layer('reshape', + inputs=val_x, + output=node, + param_attr=attr) + + def Cast(self, node): + val_input = self.graph.get_node(node.layer.input[0], copy=True) + val_output = self.graph.get_node(node.layer.output[0], copy=True) + + dtype = node.get_attr('to') + if not isinstance(dtype, np.dtype): + dtype = TENSOR_TYPE_TO_NP_TYPE[dtype] + + output_dtype = val_output.dtype + if output_dtype: + assert dtype == output_dtype, 'dtype of to unmatches output' + attr = {'dtype': string(dtype)} + node.fluid_code.add_layer('cast', + inputs=val_input, + output=node, + param_attr=attr) + + def AveragePool(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + assert node.get_attr( + 'auto_pad', + 'NOTSET') == 'NOTSET', 'only auto_pad = NOTSET is supported' + kernel_shape = node.get_attr("kernel_shape") + poolnd = len(kernel_shape) + strides = node.get_attr("strides") + pad_mode = node.get_attr("pads") + ceil_mode = bool(node.get_attr('ceil_mode', 0)) + pads = node.get_attr('pads', [0] * (poolnd * 2)) + fluid_op = 'pool{}d'.format(poolnd) + assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported' + + paddings, val_x = self._pad_if_asymmetric(node, pads, val_x) + attr = { + "pool_size": kernel_shape, + "pool_type": string('avg'), + "pool_stride": strides, + "pool_padding": paddings, + "ceil_mode": ceil_mode, + "exclusive": 'True', + "name": string(node.layer_name) + } + + node.fluid_code.add_layer(fluid_op, + inputs=val_x, + output=node, + param_attr=attr) + + def Concat(self, node): + inputs = [] + for i in range(len(node.layer.input)): + ipt = self.graph.get_node(node.layer.input[i], copy=True) + if isinstance(ipt, str): + inputs.append(ipt) + else: + inputs.append(ipt.layer_name) + axis = node.get_attr('axis') + attr = {'axis': axis} + node.fluid_code.add_layer('concat', + inputs='[' + ', '.join(inputs) + ']', + output=node, + param_attr=attr) + + def Flatten(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + axis = node.get_attr('axis', 1) + attr = {"axis": str(axis), "name": string(node.layer_name)} + node.fluid_code.add_layer('flatten', + inputs=val_x, + output=node, + param_attr=attr) + + def Gemm(self, node): + val_a = self.graph.get_node(node.layer.input[0], copy=True) + val_b = self.graph.get_node(node.layer.input[1], copy=True) + val_c = self.graph.get_node(node.layer.input[2], copy=True) + + alpha = node.get_attr('alpha', 1.) # optional + beta = node.get_attr('beta', 1.) # optional + trans_a = bool(node.get_attr('transA', 0)) # optional + trans_b = bool(node.get_attr('transB', 0)) # optional + val_mm = node.layer_name + '_mm' + matmul_inputs = {"x": val_a, "y": val_b} + attr_matmul = { + "transpose_x": trans_a, + "transpose_y": trans_b, + "alpha": alpha, + "name": string(val_mm) + } + node.fluid_code.add_layer('matmul', + inputs=matmul_inputs, + output=val_mm, + param_attr=attr_matmul) + + if beta != 0: + if beta == 1.: + add_inputs = {"x": val_mm, "y": val_c} + attr = {"name": string(node.layer_name)} + node.fluid_code.add_layer("elementwise_add", + inputs=add_inputs, + output=node, + param_attr=attr) + else: + var_beta = node.layer_name + '_beta' + matmul_beta_inputs = {"x": val_c, "y": var_beta} + node.fluid_code.add_layer("Constant", + inputs=matmul_beta_inputs, + output=var_beta, + param_attr={'value': beta}) + + add_inputs = {"x": val_mm, "y": var_beta} + attr = {"name": string(node.layer_name)} + node.fluid_code.add_layer("elementwise_add", + inputs=add_inputs, + output=node, + param_attr=attr) + + def Add(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + val_y = self.graph.get_node(node.layer.input[1], copy=True) + inputs = { + "x": val_x, + "y": val_y, + } + attr = {"name": string(node.layer_name)} + node.fluid_code.add_layer("elementwise_add", + inputs=inputs, + output=node, + param_attr=attr) + + def Sum(self, node): + var_inps = [val for val in node.layer.input] + node.fluid_code.add_layer("sum", + inputs='[' + ', '.join(var_inps) + ']', + output=node) + + def MatMul(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + val_y = self.graph.get_node(node.layer.input[1], copy=True) + inputs = {"x": val_x, "y": val_y} + attr = {"name": string(node.layer_name)} + node.fluid_code.add_layer("matmul", + inputs=inputs, + output=node, + param_attr=attr) + + def BatchNormalization(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + val_scale = self.graph.get_node(node.layer.input[1], copy=True) + val_b = self.graph.get_node(node.layer.input[2], copy=True) + val_mean = self.graph.get_node(node.layer.input[3], copy=True) + val_var = self.graph.get_node(node.layer.input[4], copy=True) + + self.omit_nodes.append(val_scale.layer_name) + self.omit_nodes.append(val_b.layer_name) + self.omit_nodes.append(val_mean.layer_name) + self.omit_nodes.append(val_var.layer_name) + + momentum = node.get_attr('momentum', .9) + epsilon = node.get_attr('epsilon', 1e-5) + + # Attribute: spatial is used in BatchNormalization-1,6,7 + spatial = bool(node.get_attr('spatial')) + attr = { + "momentum": momentum, + "epsilon": epsilon, + "data_layout": string('NCHW'), + "is_test": True, + "param_attr": string(val_scale.layer_name), + "bias_attr": string(val_b.layer_name), + "moving_mean_name": string(val_mean.layer_name), + "moving_variance_name": string(val_var.layer_name), + "use_global_stats": spatial, + "name": string(node.layer_name) + } + node.fluid_code.add_layer("batch_norm", + inputs=val_x, + output=node, + param_attr=attr) + + def Softmax(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + attr = {"name": string(node.layer_name)} + node.fluid_code.add_layer("softmax", + inputs=val_x, + output=node, + param_attr=attr) + + def Transpose(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + perm = node.get_attr('perm') + attr = {'perm': perm, "name": string(node.layer_name)} + node.fluid_code.add_layer("transpose", + inputs=val_x, + output=node, + param_attr=attr) + + def Div(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + val_y = self.graph.get_node(node.layer.input[1], copy=True) + inputs = {'x': val_x, 'y': val_y} + attr = {"name": string(node.layer_name)} + node.fluid_code.add_layer("elementwise_div", + inputs=inputs, + output=node, + param_attr=attr) + + def Relu(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + attr = {"name": string(node.layer_name)} + node.fluid_code.add_layer("relu", + inputs=val_x, + output=node, + param_attr=attr) + + def PRelu(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + val_slope = self.graph.get_node(node.layer.input[1], copy=True) + attr = {"name": string(node.layer_name), "mode": string('channel')} + + if isinstance(val_slope, str): + attr["param_attr"] = string(val_slope.layer_name) + else: + attr["param_attr"] = string(val_slope.layer_name) + node.fluid_code.add_layer("prelu", + inputs=val_x, + output=node, + param_attr=attr) + + def Squeeze(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + squeeze_dims = node.get_attr('squeeze_dims') + attr = {'axes': squeeze_dims, "name": string(node.layer_name)} + node.fluid_code.add_layer("squeeze", + inputs=val_x, + output=node, + param_attr=attr) + + def Identity(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + node.fluid_code.add_layer("assign", inputs=val_x, output=node) + + def MaxPool(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + assert node.get_attr( + 'auto_pad', 'NOTSET' + ) == 'NOTSET', 'only auto_pad = NOTSET is supported' # optional + + assert node.get_attr( + "dilations") is None, 'only dilations = 0 is supported' # optional + + kernel_shape = node.get_attr("kernel_shape") + poolnd = len(kernel_shape) + strides = node.get_attr("strides") + pad_mode = node.get_attr("pads") + ceil_mode = bool(node.get_attr('ceil_mode', 0)) # optional + pads = node.get_attr('pads', [0] * (poolnd * 2)) # optional + fluid_op = 'pool{}d'.format(poolnd) + assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported' + paddings, val_x = self._pad_if_asymmetric(node, pads, val_x) + attr = { + "pool_size": kernel_shape, + "pool_type": string("max"), + "pool_stride": strides, + "pool_padding": paddings, + "ceil_mode": ceil_mode, + "name": string(node.layer_name), + "exclusive": False + } + node.fluid_code.add_layer(fluid_op, + inputs=val_x, + output=node, + param_attr=attr) + + def GlobalAveragePool(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + val_y = self.graph.get_node(node.layer.output[0], copy=True) + input_shape = val_x.out_shapes + output_shape = val_y.out_shapes + assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # N + if input_shape: + poolnd = len(input_shape) - 2 # NC... + elif output_shape: + poolnd = len(output_shape) - 2 # NC... + assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported' + fluid_op = 'pool{}d'.format(poolnd) + attr = { + "pool_type": string("avg"), + "global_pooling": True, + "name": string(node.layer_name) + } + node.fluid_code.add_layer(fluid_op, + inputs=val_x, + output=node, + param_attr=attr) + + def Conv(self, node): + val_x = self.graph.get_node(node.layer.input[0], copy=True) + val_w = self.graph.get_node(node.layer.input[1], copy=True) + val_y = self.graph.get_node(node.layer.output[0], copy=True) + + self.omit_nodes.append(val_w.layer_name) + input_shape = val_x.out_shapes + + has_bias = len(node.layer.input) == 3 + if has_bias: + val_b = self.graph.get_node(node.layer.input[2], copy=True) + self.omit_nodes.append(val_b.layer_name) + auto_pad = node.get_attr('auto_pad', 'NOTSET') + + kernel_shape = val_w.out_shapes[2:] # OI... + assert kernel_shape == node.get_attr( + 'kernel_shape'), 'kernel_shape in attr unmatches value_info' # HW + convnd = len(kernel_shape) + assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported' + num_out_channels = val_w.out_shapes[0] # OI... + fluid_op = 'conv{}d'.format(convnd) + + num_groups = node.get_attr('group', 1) + strides = node.get_attr('strides', [1] * convnd) # optional + dilations = node.get_attr('dilations', [1] * convnd) # optional + pads = node.get_attr('pads', [0] * (convnd * 2)) # optional + + paddings, val_x = self._pad_if_asymmetric(node, pads, val_x) + + if auto_pad == "SAME_UPPER" or auto_pad == "SAME_UPPER": + pad_h = get_same_padding(input_shape[2], kernel_shape[0], + strides[0]) + pad_w = get_same_padding(input_shape[3], kernel_shape[1], + strides[1]) + attr = {"paddings": pad_h + pad_w, "pad_value": 0.0} + + attr = { + "num_filters": num_out_channels, + "filter_size": kernel_shape, + "stride": strides, + "padding": paddings, + "dilation": dilations, + "groups": num_groups, + 'param_attr': string(val_w.layer_name), + "name": string(node.layer_name) + } + if has_bias: + attr["bias_attr"] = string(val_b.layer_name) + else: + attr["bias_attr"] = False + node.fluid_code.add_layer(fluid_op, + inputs=val_x, + output=node, + param_attr=attr) diff --git a/x2paddle/optimizer/onnx_optimizer.py b/x2paddle/optimizer/onnx_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..28ffd0fdca60b353eb2881418f5d5cd1c507b5da --- /dev/null +++ b/x2paddle/optimizer/onnx_optimizer.py @@ -0,0 +1,31 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO useless node remove +from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper +from x2paddle.core.util import * + + +class ONNXOptimizer(object): + def __init__(self, op_mapper): + self.op_mapper = op_mapper + self.graph = op_mapper.graph + + def delete_redundance_code(self): + for node_name in self.graph.topo_sort: + if node_name in self.op_mapper.omit_nodes: + node = self.graph.get_node(node_name) + omit_freq = self.op_mapper.omit_nodes.count(node_name) + if len(node.outputs) <= omit_freq: + node.fluid_code.clear() diff --git a/x2paddle_model_zoo.md b/x2paddle_model_zoo.md index 335e28baacc9717cba122c4b47cc701ce8075c21..6aaf7499762dca1d44f1050b0b07767de7ca85d9 100644 --- a/x2paddle_model_zoo.md +++ b/x2paddle_model_zoo.md @@ -26,3 +26,33 @@ | ShuffleNet | [code](https://github.com/miaow1988/ShuffleNet_V2_pytorch_caffe/releases/tag/v0.1.0) | | mNASNet | [code](https://github.com/LiJianfei06/MnasNet-caffe) | | MTCNN | [code](https://github.com/kpzhang93/MTCNN_face_detection_alignment/tree/master/code/codes/MTCNNv1/model) | + +# ONNX + +| 模型 | 来源 | operator version| +|-------|--------|---------| +| Resnet18 | [torchvison.model.resnet18](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9| +| Resnet34 | [torchvison.model.resnet34](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9| +| Resnet50 | [torchvison.model.resnet50](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9| +| Resnet101 | [torchvison.model.resnet101](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9| +| Vgg11 | [torchvison.model.vgg11](https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py) |9| +| Vgg11_bn | [torchvison.model.vgg11_bn](https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py) |9| +| Vgg19| [torchvison.model.vgg19](https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py) |9| +| Densenet121 | [torchvison.model.densenet121](https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py) |9| +| Alexnet | [torchvison.model.alexnet](https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py) |9| +| Shufflenet | [onnx official](https://github.com/onnx/models/tree/master/vision/classification/shufflenet) |9| +| Inception_v2 | [onnx official](https://github.com/onnx/models/tree/master/vision/classification/inception_and_googlenet/inception_v2) |9| + +目前onnx2paddle主要支持onnx operator version 9,关于如何使用torchvison的model: +``` +import torch +import torchvision + +#根据不同模型调整输入的shape +dummy_input = torch.randn(1, 3, 224, 224) +resnet18 = torchvision.models.resnet18(pretrained=True) + +#"resnet18.onnx"为onnx model的存储路径 +torch.onnx.export(resnet18, dummy_input, "resnet18.onnx",verbose=True) + +```