diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index f10fff742b8f9d51c5540053511ea50ee07b2c36..9c205a48605c6a5952cef2f68e18803da95fd2c6 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1882,30 +1882,13 @@ class OpSet9(): @print_mapping_info def NonZero(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) - val_x_dim = len(val_x.out_shapes[0]) - if val_x_dim == 1: - self.paddle_graph.add_layer( - "paddle.nonzero", - inputs={"x": val_x.name}, - outputs=[val_x.name]) - self.paddle_graph.add_layer( - "paddle.transpose", - inputs={"x": val_x.name}, - outputs=[node.layer_name], - perm=[1, 0]) - if val_x_dim > 1: - self.paddle_graph.add_layer( - "paddle.nonzero", - inputs={"x": val_x.name}, - outputs=[val_x.name]) - self.paddle_graph.add_layer( - "paddle.split", - inputs={"x": val_x.name}, - outputs=[val_x.name], - num_or_sections=1, - axis=val_x_dim) - self.paddle_graph.add_layer( - "paddle.concat", inputs={"x": val_x.name}, outputs=[node.name]) + self.paddle_graph.add_layer( + "paddle.nonzero", + inputs={"x": val_x.name}, + outputs=[val_x.name], + as_tuple=True) + self.paddle_graph.add_layer( + "paddle.concat", inputs={"x": val_x.name}, outputs=[node.name]) @print_mapping_info def Identity(self, node): @@ -2579,27 +2562,42 @@ class OpSet9(): def TopK(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) val_k = self.graph.get_input_node(node, idx=1, copy=True) - if val_k.dtype != "int32": - self.paddle_graph.add_layer( - "paddle.cast", - inputs={"x": val_k.name}, - outputs=[val_k.name], - dtype=string('int32')) layer_attrs = dict() layer_attrs["axis"] = node.get_attr('axis', -1) layer_attrs["largest"] = True if node.get_attr('largest', 1) == 1 else False layer_attrs["sorted"] = True if node.get_attr('sorted', 1) == 1 else False - self.paddle_graph.add_layer( - "paddle.topk", - inputs={"x": val_x.name, - "k": val_k.name}, - outputs=[ - "{}_p{}".format(node.layer_name, 0), - "{}_p{}".format(node.layer_name, 1) - ], - **layer_attrs) + k = _const_weight_or_none(val_k) + if isinstance(k, (list, tuple, np.ndarray)): + k = k[0] + # If k can get the value directly, it is used as an attribute; otherwise it is used as an input tensor + if k is not None: + layer_attrs["k"] = k + self.paddle_graph.add_layer( + "paddle.topk", + inputs={"x": val_x.name}, + outputs=[ + "{}_p{}".format(node.layer_name, 0), + "{}_p{}".format(node.layer_name, 1) + ], + **layer_attrs) + else: + if val_k.dtype != "int32": + self.paddle_graph.add_layer( + "paddle.cast", + inputs={"x": val_k.name}, + outputs=[val_k.name], + dtype=string('int32')) + self.paddle_graph.add_layer( + "paddle.topk", + inputs={"x": val_x.name, + "k": val_k.name}, + outputs=[ + "{}_p{}".format(node.layer_name, 0), + "{}_p{}".format(node.layer_name, 1) + ], + **layer_attrs) @print_mapping_info def LRN(self, node):