From 06a3deb5e29edf7c7658bb2163f74973456a1250 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Tue, 7 Jun 2022 17:33:02 +0800 Subject: [PATCH] fixed nonzero --- .../op_mapper/onnx2paddle/opset9/opset.py | 31 +++++-------------- 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 1d0770a..3f0f044 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1880,30 +1880,13 @@ class OpSet9(): @print_mapping_info def NonZero(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) - val_x_dim = len(val_x.out_shapes[0]) - if val_x_dim == 1: - self.paddle_graph.add_layer( - "paddle.nonzero", - inputs={"x": val_x.name}, - outputs=[val_x.name]) - self.paddle_graph.add_layer( - "paddle.transpose", - inputs={"x": val_x.name}, - outputs=[node.layer_name], - perm=[1, 0]) - if val_x_dim > 1: - self.paddle_graph.add_layer( - "paddle.nonzero", - inputs={"x": val_x.name}, - outputs=[val_x.name]) - self.paddle_graph.add_layer( - "paddle.split", - inputs={"x": val_x.name}, - outputs=[val_x.name], - num_or_sections=1, - axis=val_x_dim) - self.paddle_graph.add_layer( - "paddle.concat", inputs={"x": val_x.name}, outputs=[node.name]) + self.paddle_graph.add_layer( + "paddle.nonzero", + inputs={"x": val_x.name}, + outputs=[val_x.name], + as_tuple=True) + self.paddle_graph.add_layer( + "paddle.concat", inputs={"x": val_x.name}, outputs=[node.name]) @print_mapping_info def Identity(self, node): -- GitLab