提交 b780979b 编写于 作者: S SunAhong1993

fix onnx

上级 bbbbd42e
...@@ -210,9 +210,12 @@ class PaddleGraph(object): ...@@ -210,9 +210,12 @@ class PaddleGraph(object):
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get( if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0 and layer.kernel != "prim.assert" \ layer_id, 0) == 0 and layer.kernel != "prim.assert" \
and layer.kernel != "prim.exception" \ and layer.kernel != "prim.exception" \
and layer.kernel != "prim.warnings" and layer.outputs[0] not in self.outputs: and layer.kernel != "prim.warnings" \
and layer.outputs[0] not in self.outputs:
if layer.kernel == "paddle.to_tensor" and layer.outputs[0] in self.inputs_info: if layer.kernel == "paddle.to_tensor" and layer.outputs[0] in self.inputs_info:
self.inputs_info.pop(layer.outputs[0]) self.inputs_info.pop(layer.outputs[0])
if layer.outputs[0] in self.inputs:
self.inputs.pop(self.inputs.index(layer.outputs[0]))
invalid_list.append(layer_id) invalid_list.append(layer_id)
for layer_id in invalid_list: for layer_id in invalid_list:
self.layers.pop(layer_id) self.layers.pop(layer_id)
...@@ -355,6 +358,8 @@ class PaddleGraph(object): ...@@ -355,6 +358,8 @@ class PaddleGraph(object):
edges_in = self.edges_in.get(layer_id, []) edges_in = self.edges_in.get(layer_id, [])
edges_out = self.edges_out.get(layer_id, []) edges_out = self.edges_out.get(layer_id, [])
if len(edges_in) == 0 and len(edges_out) == 0 and layer.outputs[0] not in self.outputs: if len(edges_in) == 0 and len(edges_out) == 0 and layer.outputs[0] not in self.outputs:
if layer.outputs[0] in self.inputs:
self.inputs.pop(self.inputs.index(layer.outputs[0]))
continue continue
line = "" line = ""
......
...@@ -31,6 +31,7 @@ import numpy as np ...@@ -31,6 +31,7 @@ import numpy as np
from copy import deepcopy from copy import deepcopy
import logging as _logging import logging as _logging
import os import os
import copy
default_op_domain = 'ai.onnx' default_op_domain = 'ai.onnx'
_logger = _logging.getLogger(__name__) _logger = _logging.getLogger(__name__)
...@@ -125,6 +126,17 @@ class ONNXGraphDataNode(GraphNode): ...@@ -125,6 +126,17 @@ class ONNXGraphDataNode(GraphNode):
shape.append(dim.dim_value) shape.append(dim.dim_value)
out_shapes.append(shape) out_shapes.append(shape)
return out_shapes return out_shapes
elif isinstance(self.layer, TensorProto):
values = self.layer.dims
out_shapes = list()
shape = list()
for dim in values:
if dim == 0:
shape.append(-1)
else:
shape.append(dim)
out_shapes.append(shape)
return out_shapes
else: else:
values = self.layer.dims values = self.layer.dims
out_shapes = list() out_shapes = list()
...@@ -227,8 +239,6 @@ class ONNXGraph(Graph): ...@@ -227,8 +239,6 @@ class ONNXGraph(Graph):
inner_nodes = self.get_inner_nodes() inner_nodes = self.get_inner_nodes()
for ipt_vi in self.graph.input: for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes: if ipt_vi.name not in inner_nodes:
if len(ipt_vi.type.tensor_type.shape.dim) == 0:
continue
self.check_input_shape(ipt_vi) self.check_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name) self.place_holder_nodes.append(ipt_vi.name)
...@@ -289,7 +299,7 @@ class ONNXGraph(Graph): ...@@ -289,7 +299,7 @@ class ONNXGraph(Graph):
#generate topo #generate topo
super(ONNXGraph, self).build() super(ONNXGraph, self).build()
self.input_nodes = self.place_holder_nodes self.input_nodes = copy.deepcopy(self.place_holder_nodes)
def build_connection(self, layer_name, node): def build_connection(self, layer_name, node):
""" """
......
...@@ -299,6 +299,9 @@ class OpSet9(): ...@@ -299,6 +299,9 @@ class OpSet9():
attrs.update({"align_corners": False, attrs.update({"align_corners": False,
"mode": string(mode), "mode": string(mode),
"align_mode": 1}) "align_mode": 1})
val_x_shape = val_x.out_shapes[0]
if mode == "linear" and len(val_x_shape) == 4:
attrs["mode"] = string("bilinear")
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate", kernel="paddle.nn.functional.interpolate",
inputs=inputs, inputs=inputs,
...@@ -1386,9 +1389,7 @@ class OpSet9(): ...@@ -1386,9 +1389,7 @@ class OpSet9():
outputs=[output_name]) outputs=[output_name])
else: else:
if mode == 'channel' and len(shape_slope) == 1: if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope) slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope)
self.weights[val_slope.name] = slope_data self.weights[val_slope.name] = slope_data
num_parameters = val_x.out_shapes[0][1] num_parameters = val_x.out_shapes[0][1]
else: else:
......
...@@ -289,6 +289,9 @@ class OpSet9(): ...@@ -289,6 +289,9 @@ class OpSet9():
attrs.update({"align_corners": False, attrs.update({"align_corners": False,
"mode": string(mode), "mode": string(mode),
"align_mode": 1}) "align_mode": 1})
val_x_shape = val_x.out_shapes[0]
if mode == "linear" and len(val_x_shape) == 4:
attrs["mode"] = string("bilinear")
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate", kernel="paddle.nn.functional.interpolate",
inputs=inputs, inputs=inputs,
...@@ -1323,9 +1326,6 @@ class OpSet9(): ...@@ -1323,9 +1326,6 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def PRelu(self, node): def PRelu(self, node):
op_name = name_generator("prelu", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_slope = self.graph.get_input_node(node, idx=1, copy=True) val_slope = self.graph.get_input_node(node, idx=1, copy=True)
...@@ -1342,12 +1342,13 @@ class OpSet9(): ...@@ -1342,12 +1342,13 @@ class OpSet9():
outputs=[node.name], outputs=[node.name],
mode="element") mode="element")
else: else:
if mode == 'channel' and len(shape_slope) == 1: if mode == 'channel':
# paddle params shape need be [1, channel] if len(shape_slope) > 1:
slope_data = _const_weight_or_none(val_slope) self.paddle_graph.add_layer(
slope_data = np.reshape(slope_data, [1] + shape_slope) "paddle.reshape",
self.params[val_slope.name] = slope_data inputs={"x": val_slope.name},
outputs=[val_slope.name],
shape=[shape_slope[0]])
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.nn.functional.prelu", "paddle.nn.functional.prelu",
inputs={"x": val_x.name, inputs={"x": val_x.name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册