提交 2ec20d64 编写于 作者: C Channingss

support for expand, fix bug of inference by onnxruntime

上级 1db5fcf7
......@@ -44,8 +44,6 @@ class ONNXGraphNode(GraphNode):
self.layer_type = layer.op_type
self.fluid_code = FluidCode()
self.attr_map = self.get_attr_map()
self.dtype_map = {1: "float32", 3: "int32", 9: "int64"}
self.weight_inputs = list()
self.out_shapes = list()
self.dtype = None
self.which_child = {}
......@@ -206,30 +204,39 @@ class ONNXGraph(Graph):
#generate connection between nodes for topo
for layer_name, node in self.node_map.items():
if isinstance(node, ONNXGraphNode):
for idx, in_node in enumerate(node.layer.input):
if in_node not in self.node_map:
flag = 0
for nd in self.model.node:
for idx, opt in enumerate(nd.output):
if opt == in_node:
self.connect(nd.name, layer_name)
flag = 1
node.which_child[nd.name] = idx
self.node_map[nd.name].index = 0
break
if flag == 1:
break
if flag == 0:
raise Exception(
'input[{}] of node[{}] does not exist in node_map'
.format(in_node, layer_name))
else:
self.connect(in_node, layer_name)
self.build_connection(layer_name, node)
#generate topo
super(ONNXGraph, self).build()
self.input_nodes = self.place_holder_nodes
def build_connection(self, layer_name, node):
"""
find connection for nodes
"""
for idx, in_node in enumerate(node.layer.input):
if in_node == '':
continue
if in_node not in self.node_map:
flag = 0
for nd in self.model.node:
for idx, opt in enumerate(nd.output):
if opt == in_node:
self.connect(nd.name, layer_name)
flag = 1
node.which_child[nd.name] = idx
self.node_map[nd.name].index = 0
break
if flag == 1:
break
if flag == 0:
raise Exception(
'input[{}] of node[{}] does not exist in node_map'.
format(in_node, layer_name))
else:
self.connect(in_node, layer_name)
def get_input_node(self, node, idx=0, copy=False):
if len(node.which_child) == 0:
ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
......@@ -450,7 +457,6 @@ class ONNXDecoder(object):
"""
make a valid code name for ParamAttr
"""
if name == '':
raise ValueError('name should not be empty')
for s in ' .*?\\/-:':
......@@ -473,6 +479,9 @@ class ONNXDecoder(object):
node.name = node.output[0]
node.name = self.make_variable_name(node.name)
for i in range(len(node.input)):
node.input[i] = self.make_variable_name(node.input[i])
if node.input[i] == '':
continue
else:
node.input[i] = self.make_variable_name(node.input[i])
for i in range(len(node.output)):
node.output[i] = self.make_variable_name(node.output[i])
......@@ -34,16 +34,14 @@ def main():
save_dir = args.save_dir
model_dir = os.path.join(save_dir, 'onnx_model_infer.onnx')
data_dir = os.path.join(save_dir, 'input_data.npy')
model = onnx.load(model_dir)
sess = rt.InferenceSession(model_dir)
inputs = np.load(data_dir, allow_pickle=True)
data_dir
inputs_dict = {}
for i, ipt in enumerate(inputs):
inputs_dict[sess.get_inputs()[i].name] = ipt
for ipt in sess.get_inputs():
data_dir = os.path.join(save_dir, ipt.name + '.npy')
inputs_dict[ipt.name] = np.load(data_dir, allow_pickle=True)
res = sess.run(None, input_feed=inputs_dict)
for idx, value_info in enumerate(model.graph.output):
np.save(os.path.join(save_dir, value_info.name), res[idx])
......
......@@ -26,8 +26,6 @@ default_op_mapping_field_values['OUTPUT_PERM'] = None
default_op_mapping_field_values['FILL_NAME_FIELD'] = True
default_op_mapping = {
'Gather': ['gather', ['X'], ['Out'],
dict(axis='')],
'Shape': ['shape', ['X'], ['Out']],
'Clip': [
'clip', ['X'], ['Out'],
......@@ -81,11 +79,6 @@ default_op_mapping = {
'Sqrt': ['sqrt', ['X'], ['Out']],
}
activefunc_op_mapping = {
'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
dict(), dict(alpha=.01)],
}
default_ioa_constraint = {
'Gather':
[(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported')],
......
......@@ -116,12 +116,14 @@ class ONNXOpMapper(OpMapper):
return False
def get_results_of_inference(self, model, value_infos, data_nodes):
inputs = []
if not os.path.exists(self.tmp_data_dir):
os.makedirs(self.tmp_data_dir)
for data_node in data_nodes:
value_info = value_infos[data_node]
ipt = np.random.random(value_info['shape']).astype(
value_info['dtype'])
inputs.append(ipt)
np.save(os.path.join(self.tmp_data_dir, data_node), ipt)
model = onnx.shape_inference.infer_shapes(model)
outputs = []
......@@ -130,11 +132,8 @@ class ONNXOpMapper(OpMapper):
model.graph.ClearField('output')
model.graph.output.MergeFrom(outputs)
if not os.path.exists(self.tmp_data_dir):
os.makedirs(self.tmp_data_dir)
onnx.save(model, os.path.join(self.tmp_data_dir,
'onnx_model_infer.onnx'))
np.save(os.path.join(self.tmp_data_dir, 'input_data.npy'), inputs)
os.system('onnx_infer --save_dir=' + self.tmp_data_dir)
return
......@@ -457,7 +456,6 @@ class ONNXOpMapper(OpMapper):
def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
if len(val_x.out_shapes[0]) == 0:
node.fluid_code.add_layer('assign',
inputs=val_x,
......@@ -491,6 +489,7 @@ class ONNXOpMapper(OpMapper):
assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
shape = node.get_attr('shape', None)
if shape is None:
shape = val_output.out_shapes[0]
if shape is None:
......@@ -536,11 +535,16 @@ class ONNXOpMapper(OpMapper):
def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1:
self.omit_nodes.append(val_shape.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = node.out_shapes[0]
val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones'
attr_ones = {'shape': out_shape, 'dtype': string('int64')}
attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)}
node.fluid_code.add_layer('ones',
inputs=None,
output=name_ones,
......@@ -724,8 +728,16 @@ class ONNXOpMapper(OpMapper):
# catch dynamic graph shape
if isinstance(val_shape, ONNXGraphNode):
shape, _, _ = self.get_dynamic_shape(val_shape.layer_name)
attr['actual_shape'] = val_shape.layer_name
if val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer('cast',
inputs=val_shape,
output=val_shape_cast,
param_attr={'dtype': string('int32')})
attr['actual_shape'] = val_shape_cast
else:
attr['actual_shape'] = val_shape
if shape is None:
shape = val_reshaped.out_shapes[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册