diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index e984e10ab58cc2edad599c55ce0ac8e711e8520f..9ffe233161627b2c549446dbec87427b0f4443e3 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -793,6 +793,14 @@ class OpSet9(): self.paddle_graph.add_layer( 'paddle.multiply', inputs=inputs_dict, outputs=[node.name]) + @print_mapping_info + def GatherND(self, node): + x = self.graph.get_input_node(node, idx=0, copy=True) + index = self.graph.get_input_node(node, idx=1, copy=True) + inputs = {'x': x.name, 'index': index.name} + self.paddle_graph.add_layer( + "paddle.gather_nd", inputs=inputs, outputs=[node.name]) + @print_mapping_info def Gather(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) @@ -1345,28 +1353,50 @@ class OpSet9(): if split is None: split = len(node.outputs) axis = node.get_attr('axis', 0) - layer_attrs = { - 'num_or_sections': split, - 'axis': axis, - } - outputs_list = list() - if isinstance(split, list) or isinstance(split, tuple): - if len(split) == 1: - outputs_list.append(node.name) - else: - for i in range(len(split)): + if split is None: + split_num = len(node.layer.output) + layer_attrs = { + 'num_or_sections': split_num, + 'axis': axis, + } + outputs_list = list() + for i in range(len(node.layer.output)): + if hasattr(node, 'index'): outputs_list.append("{}_p{}".format(node.layer_name, i)) + else: + outputs_list.append("{}".format(node.layer_name)) + if split_num > 1: + self.paddle_graph.add_layer( + 'paddle.split', + inputs={"x": val_x.name}, + outputs=outputs_list, + **layer_attrs) + else: + self.paddle_graph.add_layer( + "paddle.cast", + inputs={"x": val_x.name}, + outputs=outputs_list, + dtype=string(val_x.dtype)) + else: - if len(node.outputs) == 1: - outputs_list.append(node.name) + layer_attrs = { + 'num_or_sections': split, + 'axis': axis, + } + outputs_list = list() + if isinstance(split, list) or isinstance(split, tuple): + if len(split) == 1: + outputs_list.append(node.name) + else: + for i in range(len(split)): + outputs_list.append("{}_p{}".format(node.layer_name, i)) else: - for i in range(len(node.outputs)): - outputs_list.append("{}_p{}".format(node.layer_name, i)) - self.paddle_graph.add_layer( - 'paddle.split', - inputs={"x": val_x.name}, - outputs=outputs_list, - **layer_attrs) + outputs_list.append(node.name) + self.paddle_graph.add_layer( + 'paddle.split', + inputs={"x": val_x.name}, + outputs=outputs_list, + **layer_attrs) @print_mapping_info def Reshape(self, node):