提交 4f51c4d0 编写于 作者: C Channingss

support export to onnx

上级 aafd6fe4
...@@ -60,25 +60,35 @@ def main(): ...@@ -60,25 +60,35 @@ def main():
print("Repo: https://github.com/PaddlePaddle/PaddleX.git") print("Repo: https://github.com/PaddlePaddle/PaddleX.git")
print("Email: paddlex@baidu.com") print("Email: paddlex@baidu.com")
return return
if args.export_inference: if args.export_inference:
assert args.model_dir is not None, "--model_dir should be defined while exporting inference model" assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
assert args.save_dir is not None, "--save_dir should be defined to save inference model" assert args.save_dir is not None, "--save_dir should be defined to save inference model"
fixed_input_shape = eval(args.fixed_input_shape)
assert len( fixed_input_shape = None
fixed_input_shape) == 2, "len of fixed input shape must == 2" if args.fixed_input_shape is not None:
fixed_input_shape = eval(args.fixed_input_shape)
assert len(
fixed_input_shape) == 2, "len of fixed input shape must == 2"
model = pdx.load_model(args.model_dir, fixed_input_shape) model = pdx.load_model(args.model_dir, fixed_input_shape)
model.export_inference_model(args.save_dir) model.export_inference_model(args.save_dir)
# if args.export_onnx: if args.export_onnx:
# assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model" assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
# assert args.save_dir is not None, "--save_dir should be defined to save onnx model" assert args.save_dir is not None, "--save_dir should be defined to save onnx model"
# fixed_input_shape = eval(args.fixed_input_shape)
# assert len( fixed_input_shape = None
# fixed_input_shape) == 2, "len of fixed input shape must == 2" if args.fixed_input_shape is not None:
fixed_input_shape = eval(args.fixed_input_shape)
assert len(
fixed_input_shape) == 2, "len of fixed input shape must == 2"
model = pdx.load_model(args.model_dir, fixed_input_shape)
# model = pdx.load_model(args.model_dir, fixed_input_shape) model_name = os.path.basename(args.model_dir.strip('/')).split('/')[-1]
# model.export_onnx_model(args.save_dir) onnx_name = model_name + '.onnx'
model.export_onnx_model(args.save_dir, onnx_name=onnx_name)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import paddle.fluid as fluid import paddle.fluid as fluid
import os import os
import sys
import numpy as np import numpy as np
import time import time
import math import math
...@@ -327,129 +328,119 @@ class BaseAPI: ...@@ -327,129 +328,119 @@ class BaseAPI:
logging.info( logging.info(
"Model for inference deploy saved in {}.".format(save_dir)) "Model for inference deploy saved in {}.".format(save_dir))
# def export_onnx_model(self, save_dir, onnx_model=None): def export_onnx_model(self, save_dir, onnx_name=None):
# from fluid.utils import op_io_info, init_name_prefix from fluid.utils import op_io_info, init_name_prefix
# from onnx import helper, checker from onnx import helper, checker
# import fluid_onnx.ops as ops import fluid_onnx.ops as ops
# from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
# from debug.model_check import debug_model, Tracke from debug.model_check import debug_model, Tracker
# place = fluid.CPUPlace() place = fluid.CPUPlace()
# exe = fluid.Executor(place) exe = fluid.Executor(place)
# inference_scope = fluid.core.Scope() inference_scope = fluid.global_scope()
# with fluid.scope_guard(inference_scope): with fluid.scope_guard(inference_scope):
# test_input_names = [ test_input_names = [
# var.name for var in list(self.test_inputs.values()) var.name for var in list(self.test_inputs.values())
# ] ]
# inputs_outputs_list = ["fetch", "feed"] inputs_outputs_list = ["fetch", "feed"]
# weights, weights_value_info = [], [] weights, weights_value_info = [], []
# global_block = self.test_program.global_block() global_block = self.test_prog.global_block()
# for var_name in global_block.vars: for var_name in global_block.vars:
# var = global_block.var(var_name) var = global_block.var(var_name)
# if var_name not in feed_fetch_list\ if var_name not in test_input_names\
# and var.persistable: and var.persistable:
# weight, val_info = paddle_onnx_weight( weight, val_info = paddle_onnx_weight(
# var=var, scope=inference_scope) var=var, scope=inference_scope)
# weights.append(weight) weights.append(weight)
# weights_value_info.append(val_info) weights_value_info.append(val_info)
# # Create inputs # Create inputs
# inputs = [ inputs = [
# paddle_variable_to_onnx_tensor(v, global_block) paddle_variable_to_onnx_tensor(v, global_block)
# for v in test_input_names for v in test_input_names
# ] ]
# print("load the model parameter done.") print("load the model parameter done.")
# onnx_nodes = [] onnx_nodes = []
# op_check_list = [] op_check_list = []
# op_trackers = [] op_trackers = []
# nms_first_index = -1 nms_first_index = -1
# nms_outputs = [] nms_outputs = []
# for block in inference_program.blocks: for block in self.test_prog.blocks:
# for op in block.ops: for op in block.ops:
# if op.type in ops.node_maker: if op.type in ops.node_maker:
# # TODO(kuke): deal with the corner case that vars in # TODO(kuke): deal with the corner case that vars in
# # different blocks have the same name # different blocks have the same name
# node_proto = ops.node_maker[str(op.type)](operator=op, node_proto = ops.node_maker[str(op.type)](
# block=block) operator=op, block=block)
# op_outputs = [] op_outputs = []
# last_node = None last_node = None
# if isinstance(node_proto, tuple): if isinstance(node_proto, tuple):
# onnx_nodes.extend(list(node_proto)) onnx_nodes.extend(list(node_proto))
# last_node = list(node_proto) last_node = list(node_proto)
# else: else:
# onnx_nodes.append(node_proto) onnx_nodes.append(node_proto)
# last_node = [node_proto] last_node = [node_proto]
# tracker = Tracker(str(op.type), last_node) tracker = Tracker(str(op.type), last_node)
# op_trackers.append(tracker) op_trackers.append(tracker)
# op_check_list.append(str(op.type)) op_check_list.append(str(op.type))
# if op.type == "multiclass_nms" and nms_first_index < 0: if op.type == "multiclass_nms" and nms_first_index < 0:
# nms_first_index = 0 nms_first_index = 0
# if nms_first_index >= 0: if nms_first_index >= 0:
# _, _, output_op = op_io_info(op) _, _, output_op = op_io_info(op)
# for output in output_op: for output in output_op:
# nms_outputs.extend(output_op[output]) nms_outputs.extend(output_op[output])
# else: else:
# if op.type not in ['feed', 'fetch']: if op.type not in ['feed', 'fetch']:
# op_check_list.append(op.type) op_check_list.append(op.type)
# print('The operator sets to run test case.') print('The operator sets to run test case.')
# print(set(op_check_list)) print(set(op_check_list))
# # Create outputs # Create outputs
# # Get the new names for outputs if they've been renamed in nodes' making # Get the new names for outputs if they've been renamed in nodes' making
# renamed_outputs = op_io_info.get_all_renamed_outputs() renamed_outputs = op_io_info.get_all_renamed_outputs()
# test_outputs = list(self.test_outputs.values()) test_outputs = list(self.test_outputs.values())
# test_outputs_names = [var.name for var in self.test_outpus.values] test_outputs_names = [
# test_outputs_names = [ var.name for var in self.test_outputs.values()
# name if name not in renamed_outputs else renamed_outputs[name] ]
# for name in test_outputs_names test_outputs_names = [
# ] name if name not in renamed_outputs else renamed_outputs[name]
# outputs = [ for name in test_outputs_names
# paddle_variable_to_onnx_tensor(v, global_block) ]
# for v in test_outputs_names outputs = [
# ] paddle_variable_to_onnx_tensor(v, global_block)
# # Make graph for v in test_outputs_names
# #model_name = os.path.basename(args.fluid_model.strip('/')).split('.')[0] ]
# model_name = 'test' # Make graph
# onnx_graph = helper.make_graph( onnx_name = 'test'
# nodes=onnx_nodes, onnx_graph = helper.make_graph(
# name=model_name, nodes=onnx_nodes,
# initializer=weights, name=onnx_name,
# inputs=inputs + weights_value_info, initializer=weights,
# outputs=outputs) inputs=inputs + weights_value_info,
outputs=outputs)
# # Make model
# onnx_model = helper.make_model(onnx_graph, producer_name='PaddlePaddle') # Make model
onnx_model = helper.make_model(
# # Model check onnx_graph, producer_name='PaddlePaddle')
# checker.check_model(onnx_model)
# Model check
# # Print model checker.check_model(onnx_model)
# #if to_print_model:
# # print("The converted model is:\n{}".format(onnx_model)) # Print model
# # Save converted model #if to_print_model:
# # print("The converted model is:\n{}".format(onnx_model))
# if onnx_model is not None: # Save converted model
# try:
# onnx_model_file = osp.join(save_dir, onnx_model) if onnx_model is not None:
# with open(onnx_model_file, 'wb') as f: try:
# f.write(onnx_model.SerializeToString()) onnx_model_file = osp.join(save_dir, onnx_name)
# print("Saved converted model to path: %s" % onnx_model_file) with open(onnx_model_file, 'wb') as f:
# # If in debug mode, need to save op list, add we will check op f.write(onnx_model.SerializeToString())
# #if args.debug: print(
# # op_check_list = list(set(op_check_list)) "Saved converted model to path: %s" % onnx_model_file)
# # check_outputs = [] except Exception as e:
print(e)
# # for node_proto in onnx_nodes: print(
# # check_outputs.extend(node_proto.output) "Convert Failed! Please use the debug message to find error."
)
# # print("The num of %d operators need to check, and %d op outputs need to check."\ sys.exit(-1)
# # %(len(op_check_list), len(check_outputs)))
# # debug_model(op_check_list, op_trackers, nms_outputs, args)
# except Exception as e:
# print(e)
# print(
# "Convert Failed! Please use the debug message to find error."
# )
# sys.exit(-1)
def train_loop(self, def train_loop(self,
num_epochs, num_epochs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册