未验证 提交 bf7fd504 编写于 作者: Y yeliang2258 提交者: GitHub

fix split op in onnx (#645)

* fix expand op in onnx

* remove useless info

* fix split and add GatherND

* fix

* test revert

* update

* update

* reverse

* reverse expand
上级 b229cbd0
...@@ -793,6 +793,14 @@ class OpSet9(): ...@@ -793,6 +793,14 @@ class OpSet9():
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.multiply', inputs=inputs_dict, outputs=[node.name]) 'paddle.multiply', inputs=inputs_dict, outputs=[node.name])
@print_mapping_info
def GatherND(self, node):
x = self.graph.get_input_node(node, idx=0, copy=True)
index = self.graph.get_input_node(node, idx=1, copy=True)
inputs = {'x': x.name, 'index': index.name}
self.paddle_graph.add_layer(
"paddle.gather_nd", inputs=inputs, outputs=[node.name])
@print_mapping_info @print_mapping_info
def Gather(self, node): def Gather(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1345,28 +1353,50 @@ class OpSet9(): ...@@ -1345,28 +1353,50 @@ class OpSet9():
if split is None: if split is None:
split = len(node.outputs) split = len(node.outputs)
axis = node.get_attr('axis', 0) axis = node.get_attr('axis', 0)
layer_attrs = { if split is None:
'num_or_sections': split, split_num = len(node.layer.output)
'axis': axis, layer_attrs = {
} 'num_or_sections': split_num,
outputs_list = list() 'axis': axis,
if isinstance(split, list) or isinstance(split, tuple): }
if len(split) == 1: outputs_list = list()
outputs_list.append(node.name) for i in range(len(node.layer.output)):
else: if hasattr(node, 'index'):
for i in range(len(split)):
outputs_list.append("{}_p{}".format(node.layer_name, i)) outputs_list.append("{}_p{}".format(node.layer_name, i))
else:
outputs_list.append("{}".format(node.layer_name))
if split_num > 1:
self.paddle_graph.add_layer(
'paddle.split',
inputs={"x": val_x.name},
outputs=outputs_list,
**layer_attrs)
else:
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": val_x.name},
outputs=outputs_list,
dtype=string(val_x.dtype))
else: else:
if len(node.outputs) == 1: layer_attrs = {
outputs_list.append(node.name) 'num_or_sections': split,
'axis': axis,
}
outputs_list = list()
if isinstance(split, list) or isinstance(split, tuple):
if len(split) == 1:
outputs_list.append(node.name)
else:
for i in range(len(split)):
outputs_list.append("{}_p{}".format(node.layer_name, i))
else: else:
for i in range(len(node.outputs)): outputs_list.append(node.name)
outputs_list.append("{}_p{}".format(node.layer_name, i)) self.paddle_graph.add_layer(
self.paddle_graph.add_layer( 'paddle.split',
'paddle.split', inputs={"x": val_x.name},
inputs={"x": val_x.name}, outputs=outputs_list,
outputs=outputs_list, **layer_attrs)
**layer_attrs)
@print_mapping_info @print_mapping_info
def Reshape(self, node): def Reshape(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册