未验证 提交 f6f87497 编写于 作者: Y yeliang2258 提交者: GitHub

fix bug of clip and prelu (#680)

* fix bug of clip and prelu

* update

* update code

* add get_input_index
上级 57786e37
...@@ -48,6 +48,18 @@ class ONNXGraphNode(GraphNode): ...@@ -48,6 +48,18 @@ class ONNXGraphNode(GraphNode):
self.dtype = None self.dtype = None
self.which_child = {} self.which_child = {}
def get_input_index(self, input_name):
    """Return the position of ``input_name`` among this node's ONNX inputs.

    Args:
        input_name (str): Tensor name to look up in ``self.layer.input``
            (the ONNX NodeProto's repeated input-name field).

    Returns:
        int: Zero-based index of the first matching input name, or ``-1``
        when ``input_name`` is not one of the node's inputs.
    """
    # Iterate the repeated field directly instead of the
    # range(len(...)) / sentinel / break pattern of the original.
    for index, name in enumerate(self.layer.input):
        if name == input_name:
            return index
    return -1
def get_attr_map(self): def get_attr_map(self):
""" """
convert ONNX node attributes to dict convert ONNX node attributes to dict
...@@ -294,7 +306,6 @@ class ONNXGraph(Graph): ...@@ -294,7 +306,6 @@ class ONNXGraph(Graph):
for layer_name, node in self.node_map.items(): for layer_name, node in self.node_map.items():
if isinstance(node, ONNXGraphNode): if isinstance(node, ONNXGraphNode):
self.build_connection(layer_name, node) self.build_connection(layer_name, node)
#generate topo #generate topo
super(ONNXGraph, self).build() super(ONNXGraph, self).build()
......
...@@ -1150,6 +1150,17 @@ class OpSet9(): ...@@ -1150,6 +1150,17 @@ class OpSet9():
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
@print_mapping_info
def GatherND(self, node):
    """Map the ONNX ``GatherND`` op onto ``paddle.gather_nd``.

    Args:
        node (ONNXGraphNode): Node being converted. Its first input is
            the data tensor ``x`` and its second input is the ``index``
            tensor, matching paddle.gather_nd's signature.
    """
    # NOTE: a leftover debug print of node.inputs was removed here.
    val_x = self.graph.get_input_node(node, idx=0, copy=True)
    val_y = self.graph.get_input_node(node, idx=1, copy=True)
    self.paddle_graph.add_layer(
        "paddle.gather_nd",
        inputs={"x": val_x.name,
                "index": val_y.name},
        outputs=[node.name])
@print_mapping_info @print_mapping_info
def Clip(self, node): def Clip(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1169,23 +1180,40 @@ class OpSet9(): ...@@ -1169,23 +1180,40 @@ class OpSet9():
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
else: else:
min_ipt = self.graph.get_input_node(node, idx=1, copy=True) if len(node.inputs) == 2:
max_ipt = self.graph.get_input_node(node, idx=2, copy=True) val_ipt = self.graph.get_input_node(node, idx=1, copy=True)
min_value = _const_weight_or_none(min_ipt)
max_value = _const_weight_or_none(max_ipt) index = node.get_input_index(val_ipt.name)
if max_value.shape == (1, ):
max_value = max_value[0] val_value = _const_weight_or_none(val_ipt)
if min_value.shape == (1, ): if val_value.shape == (1, ):
min_value = min_value[0] val_value = val_value[0]
if max_value is not None and min_value is not None:
layer_attrs = {'max': max_value, 'min': min_value} if index == 1:
layer_attrs = {'min': val_value}
if index == 2:
layer_attrs = {'max': val_value}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.clip', 'paddle.clip',
inputs={"x": val_x.name}, inputs={"x": val_x.name},
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
else: else:
raise Exception("max_value or min_value can't be None") if len(node.inputs) == 3:
min_ipt = self.graph.get_input_node(node, idx=1, copy=True)
max_ipt = self.graph.get_input_node(node, idx=2, copy=True)
self.paddle_graph.add_layer(
'paddle.clip',
inputs={
"x": val_x.name,
"min": min_ipt.name,
"max": max_ipt.name
},
outputs=[node.name])
else:
raise Exception("max_value or min_value can't be None")
@print_mapping_info @print_mapping_info
def ReduceSum(self, node): def ReduceSum(self, node):
...@@ -1681,9 +1709,9 @@ class OpSet9(): ...@@ -1681,9 +1709,9 @@ class OpSet9():
num_parameters = val_x.out_shapes[0][1] num_parameters = val_x.out_shapes[0][1]
else: else:
num_parameters = 1 num_parameters = 1
slope_data = self.weights[val_slope.name]
_rename_or_remove_weight(self.weights, val_slope.name) _rename_or_remove_weight(self.weights, val_slope.name)
self.weights[op_name + '._weight'] = np.reshape( self.weights[op_name + '._weight'] = np.reshape(slope_data, [1])
self.weights[val_slope.name], [1])
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.nn.PReLU", "paddle.nn.PReLU",
inputs={"x": val_x.name}, inputs={"x": val_x.name},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册