提交 b012ffd0 编写于 作者: W wjj19950828

add input_shape_dict and rm raw_input

上级 5e5254cb
......@@ -115,6 +115,7 @@ x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel
| --weight | 当framework为caffe时,该参数指定caffe模型的参数文件路径 |
| --save_dir | 指定转换后的模型保存目录路径 |
| --model | 当framework为tensorflow/onnx时,该参数指定tensorflow的pb模型文件或onnx模型路径 |
| --input_shape_dict | **[可选]** For ONNX, 定义ONNX模型输入大小 |
| --caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None |
| --define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](./docs/inference_model_convertor/FAQ.md) |
| --enable_code_optim | **[可选]** For PyTorch, 是否对生成代码进行优化,默认为False |
......
......@@ -73,6 +73,13 @@ def arg_parser():
action="store_true",
default=False,
help="define input shape for tf model")
parser.add_argument(
"--input_shape_dict",
"-isd",
type=_text_type,
default=None,
help="define input shapes, e.g --input_shape_dict=\"{'image':[1, 3, 608, 608]}\" or" \
"--input_shape_dict=\"{'image':[1, 3, 608, 608], 'im_shape': [1, 2], 'scale_factor': [1, 2]}\"")
parser.add_argument(
"--convert_torch_project",
"-tp",
......@@ -265,6 +272,7 @@ def caffe2paddle(proto_file,
def onnx2paddle(model_path,
save_dir,
input_shape_dict=None,
convert_to_lite=False,
lite_valid_places="arm",
lite_model_type="naive_buffer",
......@@ -292,7 +300,7 @@ def onnx2paddle(model_path,
from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
model = ONNXDecoder(model_path, enable_onnx_checker)
model = ONNXDecoder(model_path, input_shape_dict, enable_onnx_checker)
mapper = ONNXOpMapper(model)
mapper.paddle_graph.build()
logging.info("Model optimizing ...")
......@@ -481,6 +489,7 @@ def main():
onnx2paddle(
args.model,
args.save_dir,
input_shape_dict=args.input_shape_dict,
convert_to_lite=args.to_lite,
lite_valid_places=args.lite_valid_places,
lite_model_type=args.lite_model_type,
......
......@@ -173,9 +173,12 @@ class ONNXGraphDataNode(GraphNode):
class ONNXGraph(Graph):
def __init__(self, onnx_model):
def __init__(self, onnx_model, input_shape_dict):
super(ONNXGraph, self).__init__(onnx_model)
self.fixed_input_shape = {}
if input_shape_dict is not None:
for k, v in eval(input_shape_dict).items():
self.fixed_input_shape["x2paddle_" + k] = v
self.initializer = {}
self.place_holder_nodes = list()
self.value_infos = {}
......@@ -216,37 +219,6 @@ class ONNXGraph(Graph):
shape.append(dim.dim_value)
return shape
def check_input_shape(self, vi):
if vi.type.HasField('tensor_type'):
for dim in vi.type.tensor_type.shape.dim:
if dim.HasField(
'dim_param') and vi.name not in self.fixed_input_shape:
shape = self.get_symbolic_shape(
vi.type.tensor_type.shape.dim)
print(
"Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
.format(vi.name, shape))
right_shape_been_input = False
while not right_shape_been_input:
try:
shape = raw_input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
except NameError:
shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
if shape.count("-1") > 1:
print("Only 1 dimension can be -1, type again:)")
else:
right_shape_been_input = True
if shape == 'N':
break
shape = [int(dim) for dim in shape.strip().split(',')]
assert shape.count(-1) <= 1, "Only one dimension can be -1"
self.fixed_input_shape[vi.name] = shape
break
def get_place_holder_nodes(self):
"""
generate place_holder node of ONNX model
......@@ -254,7 +226,6 @@ class ONNXGraph(Graph):
inner_nodes = self.get_inner_nodes()
for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes:
self.check_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name)
def get_output_nodes(self):
......@@ -416,7 +387,7 @@ class ONNXGraph(Graph):
class ONNXDecoder(object):
def __init__(self, onnx_model, enable_onnx_checker):
def __init__(self, onnx_model, input_shape_dict, enable_onnx_checker):
onnx_model = onnx.load(onnx_model)
print('model ir_version: {}, op version: {}'.format(
onnx_model.ir_version, onnx_model.opset_import[0].version))
......@@ -427,7 +398,7 @@ class ONNXDecoder(object):
onnx_model = self.optimize_model_skip_op(onnx_model)
onnx_model = self.optimize_node_name(onnx_model)
self.graph = ONNXGraph(onnx_model)
self.graph = ONNXGraph(onnx_model, input_shape_dict)
def build_value_refs(self, nodes):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册