提交 d56d4637 编写于 作者: J jiangjiajun

add x2paddle for paddle export to onnx

上级 5c04b3e6
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -15,6 +15,7 @@
from six import text_type as _text_type
import argparse
import sys
import paddlex.utils.logging as logging
def arg_parser():
......@@ -94,15 +95,15 @@ def main():
if args.export_onnx:
assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
assert args.save_dir is not None, "--save_dir should be defined to create onnx model"
assert args.fixed_input_shape is not None, "--fixed_input_shape should be defined [w,h] to create onnx model, such as [224,224]"
fixed_input_shape = []
if args.fixed_input_shape is not None:
fixed_input_shape = eval(args.fixed_input_shape)
assert len(
fixed_input_shape
) == 2, "len of fixed input shape must == 2, such as [224,224]"
model = pdx.load_model(args.model_dir, fixed_input_shape)
model = pdx.load_model(args.model_dir)
if model.status == "Normal" or model.status == "Prune":
logging.error(
"Only support inference model, try to export model first as below,",
exit=False)
logging.error(
"paddlex --export_inference --model_dir model_path --save_dir infer_model"
)
pdx.convertor.export_onnx_model(model, args.save_dir)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -30,119 +30,13 @@ def export_onnx(model_dir, save_dir, fixed_input_shape):
def export_onnx_model(model, save_dir):
support_list = [
'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet50_vd',
'ResNet101_vd', 'ResNet50_vd_ssld', 'ResNet101_vd_ssld', 'DarkNet53',
'MobileNetV1', 'MobileNetV2', 'DenseNet121', 'DenseNet161',
'DenseNet201'
]
if model.__class__.__name__ not in support_list:
raise Exception("Model: {} unsupport export to ONNX".format(
model.__class__.__name__))
try:
from fluid.utils import op_io_info, init_name_prefix
from onnx import helper, checker
import fluid_onnx.ops as ops
from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
from debug.model_check import debug_model, Tracker
except Exception as e:
import x2paddle
if x2paddle.__version__ < '0.7.4':
logging.error("You need to upgrade x2paddle >= 0.7.4")
except:
logging.error(
"Import Module Failed! Please install paddle2onnx. Related requirements see https://github.com/PaddlePaddle/paddle2onnx."
)
raise e
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.global_scope()
with fluid.scope_guard(inference_scope):
test_input_names = [
var.name for var in list(model.test_inputs.values())
]
inputs_outputs_list = ["fetch", "feed"]
weights, weights_value_info = [], []
global_block = model.test_prog.global_block()
for var_name in global_block.vars:
var = global_block.var(var_name)
if var_name not in test_input_names\
and var.persistable:
weight, val_info = paddle_onnx_weight(
var=var, scope=inference_scope)
weights.append(weight)
weights_value_info.append(val_info)
# Create inputs
inputs = [
paddle_variable_to_onnx_tensor(v, global_block)
for v in test_input_names
]
logging.INFO("load the model parameter done.")
onnx_nodes = []
op_check_list = []
op_trackers = []
nms_first_index = -1
nms_outputs = []
for block in model.test_prog.blocks:
for op in block.ops:
if op.type in ops.node_maker:
# TODO: deal with the corner case that vars in
# different blocks have the same name
node_proto = ops.node_maker[str(op.type)](
operator=op, block=block)
op_outputs = []
last_node = None
if isinstance(node_proto, tuple):
onnx_nodes.extend(list(node_proto))
last_node = list(node_proto)
else:
onnx_nodes.append(node_proto)
last_node = [node_proto]
tracker = Tracker(str(op.type), last_node)
op_trackers.append(tracker)
op_check_list.append(str(op.type))
if op.type == "multiclass_nms" and nms_first_index < 0:
nms_first_index = 0
if nms_first_index >= 0:
_, _, output_op = op_io_info(op)
for output in output_op:
nms_outputs.extend(output_op[output])
else:
if op.type not in ['feed', 'fetch']:
op_check_list.append(op.type)
logging.info('The operator sets to run test case.')
logging.info(set(op_check_list))
# Create outputs
# Get the new names for outputs if they've been renamed in nodes' making
renamed_outputs = op_io_info.get_all_renamed_outputs()
test_outputs = list(model.test_outputs.values())
test_outputs_names = [var.name for var in model.test_outputs.values()]
test_outputs_names = [
name if name not in renamed_outputs else renamed_outputs[name]
for name in test_outputs_names
]
outputs = [
paddle_variable_to_onnx_tensor(v, global_block)
for v in test_outputs_names
]
# Make graph
onnx_name = 'paddlex.onnx'
onnx_graph = helper.make_graph(
nodes=onnx_nodes,
name=onnx_name,
initializer=weights,
inputs=inputs + weights_value_info,
outputs=outputs)
# Make model
onnx_model = helper.make_model(
onnx_graph, producer_name='PaddlePaddle')
# Model check
checker.check_model(onnx_model)
if onnx_model is not None:
onnx_model_file = os.path.join(save_dir, onnx_name)
if not os.path.exists(save_dir):
os.mkdir(save_dir)
with open(onnx_model_file, 'wb') as f:
f.write(onnx_model.SerializeToString())
logging.info("Saved converted model to path: %s" % onnx_model_file)
"You need to install x2paddle first, pip install x2paddle>=0.7.4")
from x2paddle.op_mapper.paddle_op_mapper import PaddleOpMapper
mapper = PaddleOpMapper()
mapper.convert(model.test_prog, save_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册