diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
index caa9a46db012b515ba7c315cab004b544f28ca4b..fb2a2d9d24a953694b323cc70d4162a6cd58a5ec 100755
--- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
+++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
@@ -2577,27 +2577,44 @@ class OpSet9():
     def TopK(self, node):
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_k = self.graph.get_input_node(node, idx=1, copy=True)
-        if val_k.dtype != "int32":
-            self.paddle_graph.add_layer(
-                "paddle.cast",
-                inputs={"x": val_k.name},
-                outputs=[val_k.name],
-                dtype=string('int32'))
         layer_attrs = dict()
         layer_attrs["axis"] = node.get_attr('axis', -1)
         layer_attrs["largest"] = True if node.get_attr('largest',
                                                        1) == 1 else False
         layer_attrs["sorted"] = True if node.get_attr('sorted',
                                                       1) == 1 else False
-        self.paddle_graph.add_layer(
-            "paddle.topk",
-            inputs={"x": val_x.name,
-                    "k": val_k.name},
-            outputs=[
-                "{}_p{}".format(node.layer_name, 0),
-                "{}_p{}".format(node.layer_name, 1)
-            ],
-            **layer_attrs)
+        is_k_attr = False
+        k = _const_weight_or_none(val_k)
+        if isinstance(k, (list, tuple, np.ndarray)):
+            k = k[0]
+        if k is not None:
+            is_k_attr = True
+            layer_attrs["k"] = k
+        if is_k_attr:
+            self.paddle_graph.add_layer(
+                "paddle.topk",
+                inputs={"x": val_x.name},
+                outputs=[
+                    "{}_p{}".format(node.layer_name, 0),
+                    "{}_p{}".format(node.layer_name, 1)
+                ],
+                **layer_attrs)
+        else:
+            if val_k.dtype != "int32":
+                self.paddle_graph.add_layer(
+                    "paddle.cast",
+                    inputs={"x": val_k.name},
+                    outputs=[val_k.name],
+                    dtype=string('int32'))
+            self.paddle_graph.add_layer(
+                "paddle.topk",
+                inputs={"x": val_x.name,
+                        "k": val_k.name},
+                outputs=[
+                    "{}_p{}".format(node.layer_name, 0),
+                    "{}_p{}".format(node.layer_name, 1)
+                ],
+                **layer_attrs)
 
     @print_mapping_info
     def LRN(self, node):