提交 b780979b 编写于 作者: S SunAhong1993

fix onnx

上级 bbbbd42e
......@@ -210,9 +210,12 @@ class PaddleGraph(object):
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0 and layer.kernel != "prim.assert" \
and layer.kernel != "prim.exception" \
and layer.kernel != "prim.warnings" and layer.outputs[0] not in self.outputs:
and layer.kernel != "prim.warnings" \
and layer.outputs[0] not in self.outputs:
if layer.kernel == "paddle.to_tensor" and layer.outputs[0] in self.inputs_info:
self.inputs_info.pop(layer.outputs[0])
if layer.outputs[0] in self.inputs:
self.inputs.pop(self.inputs.index(layer.outputs[0]))
invalid_list.append(layer_id)
for layer_id in invalid_list:
self.layers.pop(layer_id)
......@@ -355,6 +358,8 @@ class PaddleGraph(object):
edges_in = self.edges_in.get(layer_id, [])
edges_out = self.edges_out.get(layer_id, [])
if len(edges_in) == 0 and len(edges_out) == 0 and layer.outputs[0] not in self.outputs:
if layer.outputs[0] in self.inputs:
self.inputs.pop(self.inputs.index(layer.outputs[0]))
continue
line = ""
......
......@@ -31,6 +31,7 @@ import numpy as np
from copy import deepcopy
import logging as _logging
import os
import copy
default_op_domain = 'ai.onnx'
_logger = _logging.getLogger(__name__)
......@@ -125,6 +126,17 @@ class ONNXGraphDataNode(GraphNode):
shape.append(dim.dim_value)
out_shapes.append(shape)
return out_shapes
elif isinstance(self.layer, TensorProto):
values = self.layer.dims
out_shapes = list()
shape = list()
for dim in values:
if dim == 0:
shape.append(-1)
else:
shape.append(dim)
out_shapes.append(shape)
return out_shapes
else:
values = self.layer.dims
out_shapes = list()
......@@ -227,8 +239,6 @@ class ONNXGraph(Graph):
inner_nodes = self.get_inner_nodes()
for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes:
if len(ipt_vi.type.tensor_type.shape.dim) == 0:
continue
self.check_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name)
......@@ -289,7 +299,7 @@ class ONNXGraph(Graph):
#generate topo
super(ONNXGraph, self).build()
self.input_nodes = self.place_holder_nodes
self.input_nodes = copy.deepcopy(self.place_holder_nodes)
def build_connection(self, layer_name, node):
"""
......
......@@ -299,6 +299,9 @@ class OpSet9():
attrs.update({"align_corners": False,
"mode": string(mode),
"align_mode": 1})
val_x_shape = val_x.out_shapes[0]
if mode == "linear" and len(val_x_shape) == 4:
attrs["mode"] = string("bilinear")
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate",
inputs=inputs,
......@@ -1386,9 +1389,7 @@ class OpSet9():
outputs=[output_name])
else:
if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope)
self.weights[val_slope.name] = slope_data
num_parameters = val_x.out_shapes[0][1]
else:
......
......@@ -289,6 +289,9 @@ class OpSet9():
attrs.update({"align_corners": False,
"mode": string(mode),
"align_mode": 1})
val_x_shape = val_x.out_shapes[0]
if mode == "linear" and len(val_x_shape) == 4:
attrs["mode"] = string("bilinear")
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate",
inputs=inputs,
......@@ -1323,9 +1326,6 @@ class OpSet9():
@print_mapping_info
def PRelu(self, node):
op_name = name_generator("prelu", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_slope = self.graph.get_input_node(node, idx=1, copy=True)
......@@ -1342,12 +1342,13 @@ class OpSet9():
outputs=[node.name],
mode="element")
else:
if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope)
self.params[val_slope.name] = slope_data
if mode == 'channel':
if len(shape_slope) > 1:
self.paddle_graph.add_layer(
"paddle.reshape",
inputs={"x": val_slope.name},
outputs=[val_slope.name],
shape=[shape_slope[0]])
self.paddle_graph.add_layer(
"paddle.nn.functional.prelu",
inputs={"x": val_x.name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册