提交 7b73b705 编写于 作者: C Channingss

Update: allow users to define input shapes when converting ONNX models

上级 71b51e82
...@@ -165,7 +165,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False): ...@@ -165,7 +165,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
mapper.save_inference_model(save_dir, params_merge) mapper.save_inference_model(save_dir, params_merge)
def onnx2paddle(model_path, save_dir, params_merge=False): def onnx2paddle(model_path,
save_dir,
define_input_shape=False,
params_merge=False):
# check onnx installation and version # check onnx installation and version
try: try:
import onnx import onnx
...@@ -181,7 +184,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False): ...@@ -181,7 +184,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
from x2paddle.decoder.onnx_decoder import ONNXDecoder from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
model = ONNXDecoder(model_path) model = ONNXDecoder(model_path, define_input_shape=define_input_shape)
mapper = ONNXOpMapper(model) mapper = ONNXOpMapper(model)
print("Model optimizing ...") print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper) optimizer = ONNXOptimizer(mapper)
...@@ -262,11 +265,13 @@ def main(): ...@@ -262,11 +265,13 @@ def main():
args.caffe_proto, params_merge) args.caffe_proto, params_merge)
elif args.framework == "onnx": elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model" assert args.model is not None, "--model should be defined while translating onnx model"
define_input_shape = False
params_merge = False params_merge = False
if args.define_input_shape:
define_input_shape = True
if args.params_merge: if args.params_merge:
params_merge = True params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge) onnx2paddle(args.model, args.save_dir, define_input_shape, params_merge)
elif args.framework == "paddle2onnx": elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx" assert args.model is not None, "--model should be defined while translating paddle model to onnx"
......
...@@ -18,7 +18,7 @@ from x2paddle.decoder.onnx_shape_inference import SymbolicShapeInference ...@@ -18,7 +18,7 @@ from x2paddle.decoder.onnx_shape_inference import SymbolicShapeInference
from onnx.checker import ValidationError from onnx.checker import ValidationError
from onnx.checker import check_model from onnx.checker import check_model
from onnx.utils import polish_model from onnx.utils import polish_model
from onnx import helper from onnx import helper, shape_inference
from onnx.helper import get_attribute_value, make_attribute from onnx.helper import get_attribute_value, make_attribute
from onnx.shape_inference import infer_shapes from onnx.shape_inference import infer_shapes
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
...@@ -29,11 +29,9 @@ import onnx ...@@ -29,11 +29,9 @@ import onnx
from onnx.helper import ValueInfoProto from onnx.helper import ValueInfoProto
import numpy as np import numpy as np
from copy import deepcopy from copy import deepcopy
import logging as _logging
import os import os
default_op_domain = 'ai.onnx' default_op_domain = 'ai.onnx'
_logger = _logging.getLogger(__name__)
class ONNXGraphNode(GraphNode): class ONNXGraphNode(GraphNode):
...@@ -130,18 +128,16 @@ class ONNXGraphDataNode(GraphNode): ...@@ -130,18 +128,16 @@ class ONNXGraphDataNode(GraphNode):
class ONNXGraph(Graph): class ONNXGraph(Graph):
def __init__(self, onnx_model, define_input_shape=False):
    """Build the x2paddle graph representation of an ONNX model.

    Args:
        onnx_model: loaded onnx ModelProto to wrap.
        define_input_shape (bool): when True, inputs whose shapes are
            unknown are resolved with user-supplied shapes while
            collecting placeholder nodes.
    """
    super(ONNXGraph, self).__init__(onnx_model)
    self.graph = onnx_model.graph
    self.define_input_shape = define_input_shape
    # Maps input tensor name -> user-defined shape list; filled while
    # scanning placeholders, consumed by shape_inference().
    self.fixed_input_shape = {}
    self.initializer = {}
    self.place_holder_nodes = list()
    self.value_infos = {}
    # Order matters: placeholders (and any user-fixed input shapes) must
    # be collected before shape inference, which must run before build().
    self.get_place_holder_nodes()
    self.shape_inference()
    self.build()
    self.collect_value_infos()
    self.allocate_shapes()
...@@ -152,7 +148,7 @@ class ONNXGraph(Graph): ...@@ -152,7 +148,7 @@ class ONNXGraph(Graph):
""" """
inner_nodes = [] inner_nodes = []
if not isinstance(self.graph, onnx.GraphProto): if not isinstance(self.graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance') assert 'graph is not a GraphProto instance'
return return
for initializer in self.graph.initializer: for initializer in self.graph.initializer:
name = initializer.name name = initializer.name
...@@ -168,36 +164,47 @@ class ONNXGraph(Graph): ...@@ -168,36 +164,47 @@ class ONNXGraph(Graph):
shape.append(dim.dim_value) shape.append(dim.dim_value)
return shape return shape
def shape_inference(self):
    """Infer shapes for the tensors of the graph and update self.graph.

    Tries the project's symbolic shape inference first, honoring any
    user-fixed input shapes; if that returns None (inference gave up),
    falls back to onnx's builtin shape_inference module.
    """
    print('shape inferencing ...')
    # NOTE(review): assumes the Graph base class stores the ModelProto
    # as self.model -- confirm against x2paddle's core Graph.__init__.
    infered_graph = SymbolicShapeInference.infer_shapes(
        self.model, fixed_input_shape=self.fixed_input_shape)
    if infered_graph is None:
        # Symbolic inference failed; use onnx's standard inference.
        # `shape_inference` here is the onnx module, not this method.
        infered_model = shape_inference.infer_shapes(self.model)
        self.graph = infered_model.graph
    else:
        self.graph = infered_graph
    print('shape inferenced.')
def is_static_shape(self, vi):
if vi.type.HasField('tensor_type'): if vi.type.HasField('tensor_type'):
for dim in vi.type.tensor_type.shape.dim: for dim in vi.type.tensor_type.shape.dim:
if dim.HasField( if dim.HasField(
'dim_param') and vi.name not in self.fixed_input_shape: 'dim_param') and vi.name not in self.fixed_input_shape:
shape = self.get_symbolic_shape( return False
vi.type.tensor_type.shape.dim) return True
def fix_unkown_input_shape(self, vi):
    """Ask the user to supply a concrete shape for input tensor `vi`.

    Called when an input's shape contains symbolic `dim_param` axes.
    Prompts on stdin until a valid comma-separated shape (at most one
    -1) is given, then records it in self.fixed_input_shape keyed by
    the tensor name. Entering 'N' skips this tensor, leaving its shape
    unfixed.

    Fixes vs. original: the 'N' skip path ended in `break` with no
    enclosing loop at that point (after the refactor out of the old
    per-dim loop), so 'N' would either not parse or fall through to
    int('N'); it now `return`s. The bare `except:` around raw_input
    (which also swallowed KeyboardInterrupt) is narrowed to NameError,
    the actual py3 failure mode.
    """
    shape = self.get_symbolic_shape(vi.type.tensor_type.shape.dim)
    print(
        "Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
        .format(vi.name, shape))
    right_shape_been_input = False
    while not right_shape_been_input:
        try:
            # Python 2 builtin; raises NameError on Python 3.
            shape = raw_input(
                "Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: ")
        except NameError:
            shape = input(
                "Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: ")
        if shape == 'N':
            # User chose to skip this input entirely.
            return
        if shape.count("-1") > 1:
            print("Only 1 dimension can be -1, type again:)")
        else:
            right_shape_been_input = True
    shape = [int(dim) for dim in shape.strip().split(',')]
    assert shape.count(-1) <= 1, "Only one dimension can be -1"
    self.fixed_input_shape[vi.name] = shape
def get_place_holder_nodes(self): def get_place_holder_nodes(self):
""" """
...@@ -206,7 +213,8 @@ class ONNXGraph(Graph): ...@@ -206,7 +213,8 @@ class ONNXGraph(Graph):
inner_nodes = self.get_inner_nodes() inner_nodes = self.get_inner_nodes()
for ipt_vi in self.graph.input: for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes: if ipt_vi.name not in inner_nodes:
self.check_input_shape(ipt_vi) if self.define_input_shape:
self.check_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name) self.place_holder_nodes.append(ipt_vi.name)
def get_output_nodes(self): def get_output_nodes(self):
...@@ -310,7 +318,7 @@ class ONNXGraph(Graph): ...@@ -310,7 +318,7 @@ class ONNXGraph(Graph):
""" """
if not isinstance(self.graph, onnx.GraphProto): if not isinstance(self.graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance') assert 'graph is not a GraphProto instance'
return return
for initializer in self.graph.initializer: for initializer in self.graph.initializer:
...@@ -353,7 +361,7 @@ class ONNXGraph(Graph): ...@@ -353,7 +361,7 @@ class ONNXGraph(Graph):
class ONNXDecoder(object): class ONNXDecoder(object):
def __init__(self, onnx_model): def __init__(self, onnx_model, define_input_shape=False):
onnx_model = onnx.load(onnx_model) onnx_model = onnx.load(onnx_model)
print('model ir_version: {}, op version: {}'.format( print('model ir_version: {}, op version: {}'.format(
onnx_model.ir_version, onnx_model.opset_import[0].version)) onnx_model.ir_version, onnx_model.opset_import[0].version))
...@@ -364,7 +372,7 @@ class ONNXDecoder(object): ...@@ -364,7 +372,7 @@ class ONNXDecoder(object):
onnx_model = self.optimize_model_skip_op(onnx_model) onnx_model = self.optimize_model_skip_op(onnx_model)
onnx_model = self.optimize_model_strip_initializer(onnx_model) onnx_model = self.optimize_model_strip_initializer(onnx_model)
onnx_model = self.optimize_node_name(onnx_model) onnx_model = self.optimize_node_name(onnx_model)
self.graph = ONNXGraph(onnx_model) self.graph = ONNXGraph(onnx_model, define_input_shape)
#self.onnx_model = onnx_model #self.onnx_model = onnx_model
def build_value_refs(self, nodes): def build_value_refs(self, nodes):
......
...@@ -1585,12 +1585,9 @@ class SymbolicShapeInference: ...@@ -1585,12 +1585,9 @@ class SymbolicShapeInference:
in_mp) in_mp)
symbolic_shape_inference._update_output_from_vi() symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred: if not all_shapes_inferred:
print('!' * 10)
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes( symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_) symbolic_shape_inference.out_mp_)
#onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx') #onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
except: except:
print('Stopping at incomplete shape inference') return None
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
return symbolic_shape_inference.out_mp_.graph return symbolic_shape_inference.out_mp_.graph
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册