提交 2ec20d64 编写于 作者: C Channingss

support for expand, fix bug of inference by onnxruntime

上级 1db5fcf7
...@@ -44,8 +44,6 @@ class ONNXGraphNode(GraphNode): ...@@ -44,8 +44,6 @@ class ONNXGraphNode(GraphNode):
self.layer_type = layer.op_type self.layer_type = layer.op_type
self.fluid_code = FluidCode() self.fluid_code = FluidCode()
self.attr_map = self.get_attr_map() self.attr_map = self.get_attr_map()
self.dtype_map = {1: "float32", 3: "int32", 9: "int64"}
self.weight_inputs = list()
self.out_shapes = list() self.out_shapes = list()
self.dtype = None self.dtype = None
self.which_child = {} self.which_child = {}
...@@ -206,7 +204,20 @@ class ONNXGraph(Graph): ...@@ -206,7 +204,20 @@ class ONNXGraph(Graph):
#generate connection between nodes for topo #generate connection between nodes for topo
for layer_name, node in self.node_map.items(): for layer_name, node in self.node_map.items():
if isinstance(node, ONNXGraphNode): if isinstance(node, ONNXGraphNode):
self.build_connection(layer_name, node)
#generate topo
super(ONNXGraph, self).build()
self.input_nodes = self.place_holder_nodes
def build_connection(self, layer_name, node):
"""
find connection for nodes
"""
for idx, in_node in enumerate(node.layer.input): for idx, in_node in enumerate(node.layer.input):
if in_node == '':
continue
if in_node not in self.node_map: if in_node not in self.node_map:
flag = 0 flag = 0
for nd in self.model.node: for nd in self.model.node:
...@@ -221,14 +232,10 @@ class ONNXGraph(Graph): ...@@ -221,14 +232,10 @@ class ONNXGraph(Graph):
break break
if flag == 0: if flag == 0:
raise Exception( raise Exception(
'input[{}] of node[{}] does not exist in node_map' 'input[{}] of node[{}] does not exist in node_map'.
.format(in_node, layer_name)) format(in_node, layer_name))
else: else:
self.connect(in_node, layer_name) self.connect(in_node, layer_name)
#generate topo
super(ONNXGraph, self).build()
self.input_nodes = self.place_holder_nodes
def get_input_node(self, node, idx=0, copy=False): def get_input_node(self, node, idx=0, copy=False):
if len(node.which_child) == 0: if len(node.which_child) == 0:
...@@ -450,7 +457,6 @@ class ONNXDecoder(object): ...@@ -450,7 +457,6 @@ class ONNXDecoder(object):
""" """
make a valid code name for ParamAttr make a valid code name for ParamAttr
""" """
if name == '': if name == '':
raise ValueError('name should not be empty') raise ValueError('name should not be empty')
for s in ' .*?\\/-:': for s in ' .*?\\/-:':
...@@ -473,6 +479,9 @@ class ONNXDecoder(object): ...@@ -473,6 +479,9 @@ class ONNXDecoder(object):
node.name = node.output[0] node.name = node.output[0]
node.name = self.make_variable_name(node.name) node.name = self.make_variable_name(node.name)
for i in range(len(node.input)): for i in range(len(node.input)):
if node.input[i] == '':
continue
else:
node.input[i] = self.make_variable_name(node.input[i]) node.input[i] = self.make_variable_name(node.input[i])
for i in range(len(node.output)): for i in range(len(node.output)):
node.output[i] = self.make_variable_name(node.output[i]) node.output[i] = self.make_variable_name(node.output[i])
...@@ -34,16 +34,14 @@ def main(): ...@@ -34,16 +34,14 @@ def main():
save_dir = args.save_dir save_dir = args.save_dir
model_dir = os.path.join(save_dir, 'onnx_model_infer.onnx') model_dir = os.path.join(save_dir, 'onnx_model_infer.onnx')
data_dir = os.path.join(save_dir, 'input_data.npy')
model = onnx.load(model_dir) model = onnx.load(model_dir)
sess = rt.InferenceSession(model_dir) sess = rt.InferenceSession(model_dir)
inputs = np.load(data_dir, allow_pickle=True)
data_dir
inputs_dict = {} inputs_dict = {}
for i, ipt in enumerate(inputs): for ipt in sess.get_inputs():
inputs_dict[sess.get_inputs()[i].name] = ipt data_dir = os.path.join(save_dir, ipt.name + '.npy')
inputs_dict[ipt.name] = np.load(data_dir, allow_pickle=True)
res = sess.run(None, input_feed=inputs_dict) res = sess.run(None, input_feed=inputs_dict)
for idx, value_info in enumerate(model.graph.output): for idx, value_info in enumerate(model.graph.output):
np.save(os.path.join(save_dir, value_info.name), res[idx]) np.save(os.path.join(save_dir, value_info.name), res[idx])
......
...@@ -26,8 +26,6 @@ default_op_mapping_field_values['OUTPUT_PERM'] = None ...@@ -26,8 +26,6 @@ default_op_mapping_field_values['OUTPUT_PERM'] = None
default_op_mapping_field_values['FILL_NAME_FIELD'] = True default_op_mapping_field_values['FILL_NAME_FIELD'] = True
default_op_mapping = { default_op_mapping = {
'Gather': ['gather', ['X'], ['Out'],
dict(axis='')],
'Shape': ['shape', ['X'], ['Out']], 'Shape': ['shape', ['X'], ['Out']],
'Clip': [ 'Clip': [
'clip', ['X'], ['Out'], 'clip', ['X'], ['Out'],
...@@ -81,11 +79,6 @@ default_op_mapping = { ...@@ -81,11 +79,6 @@ default_op_mapping = {
'Sqrt': ['sqrt', ['X'], ['Out']], 'Sqrt': ['sqrt', ['X'], ['Out']],
} }
activefunc_op_mapping = {
'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
dict(), dict(alpha=.01)],
}
default_ioa_constraint = { default_ioa_constraint = {
'Gather': 'Gather':
[(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported')], [(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported')],
......
...@@ -116,12 +116,14 @@ class ONNXOpMapper(OpMapper): ...@@ -116,12 +116,14 @@ class ONNXOpMapper(OpMapper):
return False return False
def get_results_of_inference(self, model, value_infos, data_nodes): def get_results_of_inference(self, model, value_infos, data_nodes):
inputs = [] if not os.path.exists(self.tmp_data_dir):
os.makedirs(self.tmp_data_dir)
for data_node in data_nodes: for data_node in data_nodes:
value_info = value_infos[data_node] value_info = value_infos[data_node]
ipt = np.random.random(value_info['shape']).astype( ipt = np.random.random(value_info['shape']).astype(
value_info['dtype']) value_info['dtype'])
inputs.append(ipt) np.save(os.path.join(self.tmp_data_dir, data_node), ipt)
model = onnx.shape_inference.infer_shapes(model) model = onnx.shape_inference.infer_shapes(model)
outputs = [] outputs = []
...@@ -130,11 +132,8 @@ class ONNXOpMapper(OpMapper): ...@@ -130,11 +132,8 @@ class ONNXOpMapper(OpMapper):
model.graph.ClearField('output') model.graph.ClearField('output')
model.graph.output.MergeFrom(outputs) model.graph.output.MergeFrom(outputs)
if not os.path.exists(self.tmp_data_dir):
os.makedirs(self.tmp_data_dir)
onnx.save(model, os.path.join(self.tmp_data_dir, onnx.save(model, os.path.join(self.tmp_data_dir,
'onnx_model_infer.onnx')) 'onnx_model_infer.onnx'))
np.save(os.path.join(self.tmp_data_dir, 'input_data.npy'), inputs)
os.system('onnx_infer --save_dir=' + self.tmp_data_dir) os.system('onnx_infer --save_dir=' + self.tmp_data_dir)
return return
...@@ -457,7 +456,6 @@ class ONNXOpMapper(OpMapper): ...@@ -457,7 +456,6 @@ class ONNXOpMapper(OpMapper):
def Unsqueeze(self, node): def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
if len(val_x.out_shapes[0]) == 0: if len(val_x.out_shapes[0]) == 0:
node.fluid_code.add_layer('assign', node.fluid_code.add_layer('assign',
inputs=val_x, inputs=val_x,
...@@ -491,6 +489,7 @@ class ONNXOpMapper(OpMapper): ...@@ -491,6 +489,7 @@ class ONNXOpMapper(OpMapper):
assert dtype == output_dtype, 'tensor dtype unmatches storage dtype' assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
shape = node.get_attr('shape', None) shape = node.get_attr('shape', None)
if shape is None: if shape is None:
shape = val_output.out_shapes[0] shape = val_output.out_shapes[0]
if shape is None: if shape is None:
...@@ -536,11 +535,16 @@ class ONNXOpMapper(OpMapper): ...@@ -536,11 +535,16 @@ class ONNXOpMapper(OpMapper):
def Expand(self, node): def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True) val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1:
self.omit_nodes.append(val_shape.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True) val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = node.out_shapes[0] out_shape = node.out_shapes[0]
val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones' name_ones = node.layer_name + '_ones'
attr_ones = {'shape': out_shape, 'dtype': string('int64')} attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)}
node.fluid_code.add_layer('ones', node.fluid_code.add_layer('ones',
inputs=None, inputs=None,
output=name_ones, output=name_ones,
...@@ -724,8 +728,16 @@ class ONNXOpMapper(OpMapper): ...@@ -724,8 +728,16 @@ class ONNXOpMapper(OpMapper):
# catch dynamic graph shape # catch dynamic graph shape
if isinstance(val_shape, ONNXGraphNode): if isinstance(val_shape, ONNXGraphNode):
shape, _, _ = self.get_dynamic_shape(val_shape.layer_name) shape, _, _ = self.get_dynamic_shape(val_shape.layer_name)
attr['actual_shape'] = val_shape.layer_name if val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer('cast',
inputs=val_shape,
output=val_shape_cast,
param_attr={'dtype': string('int32')})
attr['actual_shape'] = val_shape_cast
else:
attr['actual_shape'] = val_shape
if shape is None: if shape is None:
shape = val_reshaped.out_shapes[0] shape = val_reshaped.out_shapes[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册