diff --git a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py index f0b51d40b288462ae103425a64f4e07564b9ea39..b1c1fd816ef76cee126171b055111bfe764fe705 100644 --- a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py @@ -1637,3 +1637,16 @@ class OpSet9(): inputs=inputs_dict, outputs=[node.name], **layer_attrs) + + @print_mapping_info + def ArgMax(self, node): + val_x = self.graph.get_input_node(node, idx=0, copy=True) + axis = node.get_attr('axis') + keepdims = False if node.get_attr('keepdims') == 0 else True + layer_attrs = {'axis': axis, + 'keepdim': keepdims} + self.paddle_graph.add_layer( + 'paddle.argmax', + inputs={"x": val_x.name}, + outputs=[node.name], + **layer_attrs) diff --git a/x2paddle/op_mapper/static/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/static/onnx2paddle/opset9/opset.py index 53c28705660019a849fd8e5240402602685b64c0..e4c5439dcc56d297628198fb537b811ef7b0f1ce 100644 --- a/x2paddle/op_mapper/static/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/static/onnx2paddle/opset9/opset.py @@ -1576,4 +1576,17 @@ class OpSet9(): kernel=paddle_op, inputs=layer_inputs, outputs=[node.name], + **layer_attrs) + + @print_mapping_info + def ArgMax(self, node): + val_x = self.graph.get_input_node(node, idx=0, copy=True) + axis = node.get_attr('axis') + keepdims = False if node.get_attr('keepdims') == 0 else True + layer_attrs = {'axis': axis, + 'keepdim': keepdims} + self.paddle_graph.add_layer( + 'paddle.argmax', + inputs={"x": val_x.name}, + outputs=[node.name], **layer_attrs) \ No newline at end of file