未验证 提交 f6c9db85 编写于 作者: J Jason 提交者: GitHub

Merge pull request #211 from Channingss/develop

fix bug for output_nodes
...@@ -162,6 +162,16 @@ class ONNXGraph(Graph): ...@@ -162,6 +162,16 @@ class ONNXGraph(Graph):
if ipt_data not in inner_nodes: if ipt_data not in inner_nodes:
self.place_holder_nodes.append(ipt_data) self.place_holder_nodes.append(ipt_data)
def get_output_nodes(self):
"""
generate output_nodes node of ONNX model
"""
inner_nodes = self.get_inner_nodes()
output_nodes = [value.name for value in self.model.output]
for opt_data in output_nodes:
if opt_data not in inner_nodes:
self.output_nodes.append(opt_data)
def is_place_holder_nodes(self, layer): def is_place_holder_nodes(self, layer):
""" """
return layer is or not place_holder node return layer is or not place_holder node
......
...@@ -140,6 +140,7 @@ class ONNXOpMapper(OpMapper): ...@@ -140,6 +140,7 @@ class ONNXOpMapper(OpMapper):
model.graph.output.MergeFrom(outputs) model.graph.output.MergeFrom(outputs)
onnx.save(model, os.path.join(self.tmp_data_dir, onnx.save(model, os.path.join(self.tmp_data_dir,
'onnx_model_infer.onnx')) 'onnx_model_infer.onnx'))
os.system('onnx_infer --save_dir=' + self.tmp_data_dir) os.system('onnx_infer --save_dir=' + self.tmp_data_dir)
return return
...@@ -336,7 +337,8 @@ class ONNXOpMapper(OpMapper): ...@@ -336,7 +337,8 @@ class ONNXOpMapper(OpMapper):
node = parameter node = parameter
dtype = node.dtype dtype = node.dtype
shape = node.out_shapes[0] shape = node.out_shapes[0]
if len(node.weight.shape) == 0:
shape = [1]
self.weights[node.layer_name] = node.weight self.weights[node.layer_name] = node.weight
attr = { attr = {
'dtype': string(dtype), 'dtype': string(dtype),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册