diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 1d0770a5225ffcb402ee0453215111007b415a60..3f0f0443952226f5dcf09c81278158ce4e94f540 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1880,30 +1880,13 @@ class OpSet9(): @print_mapping_info def NonZero(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) - val_x_dim = len(val_x.out_shapes[0]) - if val_x_dim == 1: - self.paddle_graph.add_layer( - "paddle.nonzero", - inputs={"x": val_x.name}, - outputs=[val_x.name]) - self.paddle_graph.add_layer( - "paddle.transpose", - inputs={"x": val_x.name}, - outputs=[node.layer_name], - perm=[1, 0]) - if val_x_dim > 1: - self.paddle_graph.add_layer( - "paddle.nonzero", - inputs={"x": val_x.name}, - outputs=[val_x.name]) - self.paddle_graph.add_layer( - "paddle.split", - inputs={"x": val_x.name}, - outputs=[val_x.name], - num_or_sections=1, - axis=val_x_dim) - self.paddle_graph.add_layer( - "paddle.concat", inputs={"x": val_x.name}, outputs=[node.name]) + self.paddle_graph.add_layer( + "paddle.nonzero", + inputs={"x": val_x.name}, + outputs=[val_x.name], + as_tuple=True) + self.paddle_graph.add_layer( + "paddle.concat", inputs={"x": val_x.name}, outputs=[node.name]) @print_mapping_info def Identity(self, node):