提交 b012ffd0 编写于 作者: W wjj19950828

add input_shape_dict and rm raw_input

上级 5e5254cb
...@@ -115,6 +115,7 @@ x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel ...@@ -115,6 +115,7 @@ x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel
| --weight | 当framework为caffe时,该参数指定caffe模型的参数文件路径 | | --weight | 当framework为caffe时,该参数指定caffe模型的参数文件路径 |
| --save_dir | 指定转换后的模型保存目录路径 | | --save_dir | 指定转换后的模型保存目录路径 |
| --model | 当framework为tensorflow/onnx时,该参数指定tensorflow的pb模型文件或onnx模型路径 | | --model | 当framework为tensorflow/onnx时,该参数指定tensorflow的pb模型文件或onnx模型路径 |
| --input_shape_dict | **[可选]** For ONNX, 定义ONNX模型输入大小 |
| --caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None | | --caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None |
| --define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](./docs/inference_model_convertor/FAQ.md) | | --define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](./docs/inference_model_convertor/FAQ.md) |
| --enable_code_optim | **[可选]** For PyTorch, 是否对生成代码进行优化,默认为False | | --enable_code_optim | **[可选]** For PyTorch, 是否对生成代码进行优化,默认为False |
......
...@@ -73,6 +73,13 @@ def arg_parser(): ...@@ -73,6 +73,13 @@ def arg_parser():
action="store_true", action="store_true",
default=False, default=False,
help="define input shape for tf model") help="define input shape for tf model")
parser.add_argument(
"--input_shape_dict",
"-isd",
type=_text_type,
default=None,
help="define input shapes, e.g --input_shape_dict=\"{'image':[1, 3, 608, 608]}\" or" \
"--input_shape_dict=\"{'image':[1, 3, 608, 608], 'im_shape': [1, 2], 'scale_factor': [1, 2]}\"")
parser.add_argument( parser.add_argument(
"--convert_torch_project", "--convert_torch_project",
"-tp", "-tp",
...@@ -265,6 +272,7 @@ def caffe2paddle(proto_file, ...@@ -265,6 +272,7 @@ def caffe2paddle(proto_file,
def onnx2paddle(model_path, def onnx2paddle(model_path,
save_dir, save_dir,
input_shape_dict=None,
convert_to_lite=False, convert_to_lite=False,
lite_valid_places="arm", lite_valid_places="arm",
lite_model_type="naive_buffer", lite_model_type="naive_buffer",
...@@ -292,7 +300,7 @@ def onnx2paddle(model_path, ...@@ -292,7 +300,7 @@ def onnx2paddle(model_path,
from x2paddle.decoder.onnx_decoder import ONNXDecoder from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
model = ONNXDecoder(model_path, enable_onnx_checker) model = ONNXDecoder(model_path, input_shape_dict, enable_onnx_checker)
mapper = ONNXOpMapper(model) mapper = ONNXOpMapper(model)
mapper.paddle_graph.build() mapper.paddle_graph.build()
logging.info("Model optimizing ...") logging.info("Model optimizing ...")
...@@ -481,6 +489,7 @@ def main(): ...@@ -481,6 +489,7 @@ def main():
onnx2paddle( onnx2paddle(
args.model, args.model,
args.save_dir, args.save_dir,
input_shape_dict=args.input_shape_dict,
convert_to_lite=args.to_lite, convert_to_lite=args.to_lite,
lite_valid_places=args.lite_valid_places, lite_valid_places=args.lite_valid_places,
lite_model_type=args.lite_model_type, lite_model_type=args.lite_model_type,
......
...@@ -173,9 +173,12 @@ class ONNXGraphDataNode(GraphNode): ...@@ -173,9 +173,12 @@ class ONNXGraphDataNode(GraphNode):
class ONNXGraph(Graph): class ONNXGraph(Graph):
def __init__(self, onnx_model): def __init__(self, onnx_model, input_shape_dict):
super(ONNXGraph, self).__init__(onnx_model) super(ONNXGraph, self).__init__(onnx_model)
self.fixed_input_shape = {} self.fixed_input_shape = {}
if input_shape_dict is not None:
for k, v in eval(input_shape_dict).items():
self.fixed_input_shape["x2paddle_" + k] = v
self.initializer = {} self.initializer = {}
self.place_holder_nodes = list() self.place_holder_nodes = list()
self.value_infos = {} self.value_infos = {}
...@@ -216,37 +219,6 @@ class ONNXGraph(Graph): ...@@ -216,37 +219,6 @@ class ONNXGraph(Graph):
shape.append(dim.dim_value) shape.append(dim.dim_value)
return shape return shape
def check_input_shape(self, vi):
if vi.type.HasField('tensor_type'):
for dim in vi.type.tensor_type.shape.dim:
if dim.HasField(
'dim_param') and vi.name not in self.fixed_input_shape:
shape = self.get_symbolic_shape(
vi.type.tensor_type.shape.dim)
print(
"Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
.format(vi.name, shape))
right_shape_been_input = False
while not right_shape_been_input:
try:
shape = raw_input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
except NameError:
shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
if shape.count("-1") > 1:
print("Only 1 dimension can be -1, type again:)")
else:
right_shape_been_input = True
if shape == 'N':
break
shape = [int(dim) for dim in shape.strip().split(',')]
assert shape.count(-1) <= 1, "Only one dimension can be -1"
self.fixed_input_shape[vi.name] = shape
break
def get_place_holder_nodes(self): def get_place_holder_nodes(self):
""" """
generate place_holder node of ONNX model generate place_holder node of ONNX model
...@@ -254,7 +226,6 @@ class ONNXGraph(Graph): ...@@ -254,7 +226,6 @@ class ONNXGraph(Graph):
inner_nodes = self.get_inner_nodes() inner_nodes = self.get_inner_nodes()
for ipt_vi in self.graph.input: for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes: if ipt_vi.name not in inner_nodes:
self.check_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name) self.place_holder_nodes.append(ipt_vi.name)
def get_output_nodes(self): def get_output_nodes(self):
...@@ -416,7 +387,7 @@ class ONNXGraph(Graph): ...@@ -416,7 +387,7 @@ class ONNXGraph(Graph):
class ONNXDecoder(object): class ONNXDecoder(object):
def __init__(self, onnx_model, enable_onnx_checker): def __init__(self, onnx_model, input_shape_dict, enable_onnx_checker):
onnx_model = onnx.load(onnx_model) onnx_model = onnx.load(onnx_model)
print('model ir_version: {}, op version: {}'.format( print('model ir_version: {}, op version: {}'.format(
onnx_model.ir_version, onnx_model.opset_import[0].version)) onnx_model.ir_version, onnx_model.opset_import[0].version))
...@@ -427,7 +398,7 @@ class ONNXDecoder(object): ...@@ -427,7 +398,7 @@ class ONNXDecoder(object):
onnx_model = self.optimize_model_skip_op(onnx_model) onnx_model = self.optimize_model_skip_op(onnx_model)
onnx_model = self.optimize_node_name(onnx_model) onnx_model = self.optimize_node_name(onnx_model)
self.graph = ONNXGraph(onnx_model) self.graph = ONNXGraph(onnx_model, input_shape_dict)
def build_value_refs(self, nodes): def build_value_refs(self, nodes):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册