From 9663d20f05eaf6ef4990d8caea63ee4a196c0838 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Wed, 1 Jun 2022 22:06:38 +0800 Subject: [PATCH] fixed topk --- .../op_mapper/onnx2paddle/opset9/opset.py | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index caa9a46..fb2a2d9 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -2577,27 +2577,44 @@ class OpSet9(): def TopK(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) val_k = self.graph.get_input_node(node, idx=1, copy=True) - if val_k.dtype != "int32": - self.paddle_graph.add_layer( - "paddle.cast", - inputs={"x": val_k.name}, - outputs=[val_k.name], - dtype=string('int32')) layer_attrs = dict() layer_attrs["axis"] = node.get_attr('axis', -1) layer_attrs["largest"] = True if node.get_attr('largest', 1) == 1 else False layer_attrs["sorted"] = True if node.get_attr('sorted', 1) == 1 else False - self.paddle_graph.add_layer( - "paddle.topk", - inputs={"x": val_x.name, - "k": val_k.name}, - outputs=[ - "{}_p{}".format(node.layer_name, 0), - "{}_p{}".format(node.layer_name, 1) - ], - **layer_attrs) + is_k_attr = False + k = _const_weight_or_none(val_k) + if isinstance(k, (list, tuple, np.ndarray)): + k = k[0] + if k is not None: + is_k_attr = True + layer_attrs["k"] = k + if is_k_attr: + self.paddle_graph.add_layer( + "paddle.topk", + inputs={"x": val_x.name}, + outputs=[ + "{}_p{}".format(node.layer_name, 0), + "{}_p{}".format(node.layer_name, 1) + ], + **layer_attrs) + else: + if val_k.dtype != "int32": + self.paddle_graph.add_layer( + "paddle.cast", + inputs={"x": val_k.name}, + outputs=[val_k.name], + dtype=string('int32')) + self.paddle_graph.add_layer( + "paddle.topk", + inputs={"x": val_x.name, + "k": val_k.name}, + outputs=[ + "{}_p{}".format(node.layer_name, 0), + "{}_p{}".format(node.layer_name, 1) + ], + **layer_attrs) @print_mapping_info def LRN(self, node): -- GitLab