提交 7b73b705 编写于 作者: C Channingss

update define input shape

上级 71b51e82
......@@ -165,7 +165,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
mapper.save_inference_model(save_dir, params_merge)
def onnx2paddle(model_path, save_dir, params_merge=False):
def onnx2paddle(model_path,
save_dir,
define_input_shape=False,
params_merge=False):
# check onnx installation and version
try:
import onnx
......@@ -181,7 +184,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
model = ONNXDecoder(model_path)
model = ONNXDecoder(model_path, define_input_shape=define_input_shape)
mapper = ONNXOpMapper(model)
print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper)
......@@ -262,11 +265,13 @@ def main():
args.caffe_proto, params_merge)
elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model"
define_input_shape = False
params_merge = False
if args.define_input_shape:
define_input_shape = True
if args.params_merge:
params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge)
onnx2paddle(args.model, args.save_dir, define_input_shape, params_merge)
elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx"
......
......@@ -18,7 +18,7 @@ from x2paddle.decoder.onnx_shape_inference import SymbolicShapeInference
from onnx.checker import ValidationError
from onnx.checker import check_model
from onnx.utils import polish_model
from onnx import helper
from onnx import helper, shape_inference
from onnx.helper import get_attribute_value, make_attribute
from onnx.shape_inference import infer_shapes
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
......@@ -29,11 +29,9 @@ import onnx
from onnx.helper import ValueInfoProto
import numpy as np
from copy import deepcopy
import logging as _logging
import os
default_op_domain = 'ai.onnx'
_logger = _logging.getLogger(__name__)
class ONNXGraphNode(GraphNode):
......@@ -130,18 +128,16 @@ class ONNXGraphDataNode(GraphNode):
class ONNXGraph(Graph):
def __init__(self, onnx_model):
def __init__(self, onnx_model, define_input_shape=False):
super(ONNXGraph, self).__init__(onnx_model)
self.graph = onnx_model.graph
self.define_input_shape = define_input_shape
self.fixed_input_shape = {}
self.initializer = {}
self.place_holder_nodes = list()
self.value_infos = {}
self.graph = onnx_model.graph
self.get_place_holder_nodes()
print("shape inferencing ...")
self.graph = SymbolicShapeInference.infer_shapes(
onnx_model, fixed_input_shape=self.fixed_input_shape)
print("shape inferenced.")
self.shape_inference()
self.build()
self.collect_value_infos()
self.allocate_shapes()
......@@ -152,7 +148,7 @@ class ONNXGraph(Graph):
"""
inner_nodes = []
if not isinstance(self.graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
assert 'graph is not a GraphProto instance'
return
for initializer in self.graph.initializer:
name = initializer.name
......@@ -168,13 +164,27 @@ class ONNXGraph(Graph):
shape.append(dim.dim_value)
return shape
def check_input_shape(self, vi):
def shape_inference(self):
    """Populate ``self.graph`` with an inferred-shape version of the model.

    Tries the project's symbolic shape inference first (honoring any
    user-pinned shapes in ``self.fixed_input_shape``); when that pass
    returns ``None``, falls back to ONNX's standard shape inference.
    """
    print('shape inferencing ...')
    symbolic_graph = SymbolicShapeInference.infer_shapes(
        self.model, fixed_input_shape=self.fixed_input_shape)
    if symbolic_graph is not None:
        self.graph = symbolic_graph
    else:
        # Symbolic pass gave up; use onnx's built-in inference instead.
        fallback_model = shape_inference.infer_shapes(self.model)
        self.graph = fallback_model.graph
    print('shape inferenced.')
def is_static_shape(self, vi):
    """Return True when ValueInfo ``vi`` describes a fully static tensor shape.

    A shape counts as static when no dimension is symbolic (i.e. carries a
    ``dim_param`` instead of a concrete ``dim_value``), or when the user has
    already pinned this input's shape in ``self.fixed_input_shape``.
    Non-tensor types are treated as static.

    Args:
        vi: an ONNX ``ValueInfoProto``-like object.

    Returns:
        bool: False iff some dimension is symbolic and the input is not
        in ``self.fixed_input_shape``.
    """
    if vi.type.HasField('tensor_type'):
        for dim in vi.type.tensor_type.shape.dim:
            # Removed a dead `shape = self.get_symbolic_shape(...)` call
            # whose result was never used before returning.
            if dim.HasField(
                    'dim_param') and vi.name not in self.fixed_input_shape:
                return False
    return True
def fix_unkown_input_shape(self, vi):
shape = self.get_symbolic_shape(vi.type.tensor_type.shape.dim)
print(
"Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
.format(vi.name, shape))
......@@ -182,12 +192,10 @@ class ONNXGraph(Graph):
while not right_shape_been_input:
try:
shape = raw_input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: ")
except:
shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: ")
if shape.count("-1") > 1:
print("Only 1 dimension can be -1, type again:)")
else:
......@@ -197,7 +205,6 @@ class ONNXGraph(Graph):
shape = [int(dim) for dim in shape.strip().split(',')]
assert shape.count(-1) <= 1, "Only one dimension can be -1"
self.fixed_input_shape[vi.name] = shape
break
def get_place_holder_nodes(self):
"""
......@@ -206,6 +213,7 @@ class ONNXGraph(Graph):
inner_nodes = self.get_inner_nodes()
for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes:
if self.define_input_shape:
self.check_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name)
......@@ -310,7 +318,7 @@ class ONNXGraph(Graph):
"""
if not isinstance(self.graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
assert 'graph is not a GraphProto instance'
return
for initializer in self.graph.initializer:
......@@ -353,7 +361,7 @@ class ONNXGraph(Graph):
class ONNXDecoder(object):
def __init__(self, onnx_model):
def __init__(self, onnx_model, define_input_shape=False):
onnx_model = onnx.load(onnx_model)
print('model ir_version: {}, op version: {}'.format(
onnx_model.ir_version, onnx_model.opset_import[0].version))
......@@ -364,7 +372,7 @@ class ONNXDecoder(object):
onnx_model = self.optimize_model_skip_op(onnx_model)
onnx_model = self.optimize_model_strip_initializer(onnx_model)
onnx_model = self.optimize_node_name(onnx_model)
self.graph = ONNXGraph(onnx_model)
self.graph = ONNXGraph(onnx_model, define_input_shape)
#self.onnx_model = onnx_model
def build_value_refs(self, nodes):
......
......@@ -1585,12 +1585,9 @@ class SymbolicShapeInference:
in_mp)
symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred:
print('!' * 10)
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
#onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
except:
print('Stopping at incomplete shape inference')
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
return None
return symbolic_shape_inference.out_mp_.graph
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册