提交 fcd63263 编写于 作者: S SunAhong1993

for onnx argmax

上级 55d5eb24
...@@ -1637,3 +1637,16 @@ class OpSet9(): ...@@ -1637,3 +1637,16 @@ class OpSet9():
inputs=inputs_dict, inputs=inputs_dict,
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
@print_mapping_info
def ArgMax(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = node.get_attr('axis')
keepdims = False if node.get_attr('keepdims') == 0 else True
layer_attrs = {'axis': axis,
'keepdim': keepdims}
self.paddle_graph.add_layer(
'paddle.argmax',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
...@@ -1576,4 +1576,17 @@ class OpSet9(): ...@@ -1576,4 +1576,17 @@ class OpSet9():
kernel=paddle_op, kernel=paddle_op,
inputs=layer_inputs, inputs=layer_inputs,
outputs=[node.name], outputs=[node.name],
**layer_attrs)
@print_mapping_info
def ArgMax(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = node.get_attr('axis')
keepdims = False if node.get_attr('keepdims') == 0 else True
layer_attrs = {'axis': axis,
'keepdim': keepdims}
self.paddle_graph.add_layer(
'paddle.argmax',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs) **layer_attrs)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册