diff --git a/README.md b/README.md
index df597287c478fa151e8de7bae7e97f7841a145d9..e3d87dae7067e0cfb6278949d0d738ab665212f3 100644
--- a/README.md
+++ b/README.md
@@ -115,6 +115,7 @@ x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel
 | --weight | When framework is caffe, specifies the path to the Caffe model's weights file |
 | --save_dir | Directory in which to save the converted model |
 | --model | When framework is tensorflow/onnx, specifies the TensorFlow pb model file or the ONNX model path |
+| --input_shape_dict | **[optional]** For ONNX, defines the input shapes of the ONNX model |
 | --caffe_proto | **[optional]** Path to the caffe_pb2.py file compiled from caffe.proto; used when custom layers exist, defaults to None |
 | --define_input_shape | **[optional]** For TensorFlow, when set, the user is required to enter the shape of each Placeholder; see [Q2 in the docs](./docs/inference_model_convertor/FAQ.md) |
 | --enable_code_optim | **[optional]** For PyTorch, whether to optimize the generated code, defaults to False |
diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index 336feb36e32c70acb543edb9f193d47262750739..cadd9fd9d6a0b246c3efced681031e9596587566 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -73,6 +73,13 @@ def arg_parser():
         action="store_true",
         default=False,
         help="define input shape for tf model")
+    parser.add_argument(
+        "--input_shape_dict",
+        "-isd",
+        type=_text_type,
+        default=None,
+        help="define input shapes, e.g. --input_shape_dict=\"{'image':[1, 3, 608, 608]}\" or " \
+        "--input_shape_dict=\"{'image':[1, 3, 608, 608], 'im_shape': [1, 2], 'scale_factor': [1, 2]}\"")
     parser.add_argument(
         "--convert_torch_project",
         "-tp",
@@ -265,6 +272,7 @@ def caffe2paddle(proto_file,
 
 def onnx2paddle(model_path,
                 save_dir,
+                input_shape_dict=None,
                 convert_to_lite=False,
                 lite_valid_places="arm",
                 lite_model_type="naive_buffer",
@@ -292,7 +300,7 @@ def onnx2paddle(model_path,
     from x2paddle.decoder.onnx_decoder import ONNXDecoder
     from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
 
-    model = ONNXDecoder(model_path, enable_onnx_checker)
+    model = ONNXDecoder(model_path, input_shape_dict, enable_onnx_checker)
     mapper = ONNXOpMapper(model)
     mapper.paddle_graph.build()
     logging.info("Model optimizing ...")
@@ -481,6 +489,7 @@ def main():
         onnx2paddle(
             args.model,
             args.save_dir,
+            input_shape_dict=args.input_shape_dict,
             convert_to_lite=args.to_lite,
             lite_valid_places=args.lite_valid_places,
             lite_model_type=args.lite_model_type,
diff --git a/x2paddle/decoder/onnx_decoder.py b/x2paddle/decoder/onnx_decoder.py
index 57e5cbe388571959aab5dcaa974bb5c6979344a6..0514d19d8c8ec573507eaf034006b618529dcb8c 100755
--- a/x2paddle/decoder/onnx_decoder.py
+++ b/x2paddle/decoder/onnx_decoder.py
@@ -173,9 +173,12 @@ class ONNXGraphDataNode(GraphNode):
 
 
 class ONNXGraph(Graph):
-    def __init__(self, onnx_model):
+    def __init__(self, onnx_model, input_shape_dict):
         super(ONNXGraph, self).__init__(onnx_model)
         self.fixed_input_shape = {}
+        if input_shape_dict is not None:
+            for k, v in eval(input_shape_dict).items():
+                self.fixed_input_shape["x2paddle_" + k] = v
         self.initializer = {}
         self.place_holder_nodes = list()
         self.value_infos = {}
@@ -216,37 +219,6 @@ class ONNXGraph(Graph):
                 shape.append(dim.dim_value)
         return shape
 
-    def check_input_shape(self, vi):
-        if vi.type.HasField('tensor_type'):
-            for dim in vi.type.tensor_type.shape.dim:
-                if dim.HasField(
-                        'dim_param') and vi.name not in self.fixed_input_shape:
-                    shape = self.get_symbolic_shape(
-                        vi.type.tensor_type.shape.dim)
-                    print(
-                        "Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
-                        .format(vi.name, shape))
-                    right_shape_been_input = False
-                    while not right_shape_been_input:
-                        try:
-                            shape = raw_input(
-                                "Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
-                            )
-                        except NameError:
-                            shape = input(
-                                "Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
-                            )
-                        if shape.count("-1") > 1:
-                            print("Only 1 dimension can be -1, type again:)")
-                        else:
-                            right_shape_been_input = True
-                    if shape == 'N':
-                        break
-                    shape = [int(dim) for dim in shape.strip().split(',')]
-                    assert shape.count(-1) <= 1, "Only one dimension can be -1"
-                    self.fixed_input_shape[vi.name] = shape
-                    break
-
     def get_place_holder_nodes(self):
         """
         generate place_holder node of ONNX model
@@ -254,7 +226,6 @@ class ONNXGraph(Graph):
         inner_nodes = self.get_inner_nodes()
         for ipt_vi in self.graph.input:
             if ipt_vi.name not in inner_nodes:
-                self.check_input_shape(ipt_vi)
                 self.place_holder_nodes.append(ipt_vi.name)
 
     def get_output_nodes(self):
@@ -416,7 +387,7 @@ class ONNXGraph(Graph):
 
 
 class ONNXDecoder(object):
-    def __init__(self, onnx_model, enable_onnx_checker):
+    def __init__(self, onnx_model, input_shape_dict, enable_onnx_checker):
         onnx_model = onnx.load(onnx_model)
         print('model ir_version: {}, op version: {}'.format(
             onnx_model.ir_version, onnx_model.opset_import[0].version))
@@ -427,7 +398,7 @@ class ONNXDecoder(object):
         onnx_model = self.optimize_model_skip_op(onnx_model)
         onnx_model = self.optimize_node_name(onnx_model)
 
-        self.graph = ONNXGraph(onnx_model)
+        self.graph = ONNXGraph(onnx_model, input_shape_dict)
 
     def build_value_refs(self, nodes):
         """
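
Usage sketch (not part of the patch): the snippet below illustrates how the new `--input_shape_dict` value flows into the decoder, mirroring the `ONNXGraph.__init__` change above. The CLI line in the comment is hypothetical (model path and shapes are placeholders), and `ast.literal_eval` is substituted for the patch's `eval` as a safer stand-in.

```python
# Illustrative sketch only; mirrors the ONNXGraph.__init__ change above, not part of the patch.
# Hypothetical CLI invocation (dict format taken from the argparse help string in convert.py):
#   x2paddle --framework=onnx --model=model.onnx --save_dir=pd_model \
#            --input_shape_dict="{'image': [1, 3, 608, 608]}"
import ast


def build_fixed_input_shape(input_shape_dict):
    """Parse the --input_shape_dict string into a fixed_input_shape mapping."""
    fixed_input_shape = {}
    if input_shape_dict is not None:
        # The patch calls eval() on the string; ast.literal_eval is used here as a safer equivalent.
        for name, shape in ast.literal_eval(input_shape_dict).items():
            # Input names are stored with the "x2paddle_" prefix, matching the decoder's node naming.
            fixed_input_shape["x2paddle_" + name] = shape
    return fixed_input_shape


print(build_fixed_input_shape("{'image': [1, 3, 608, 608], 'im_shape': [1, 2]}"))
# -> {'x2paddle_image': [1, 3, 608, 608], 'x2paddle_im_shape': [1, 2]}
```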