From 64a287a02953fd4e3ac407b53bbc3b07a34fecc3 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 19 May 2020 15:30:30 +0800 Subject: [PATCH] fixed arg_max --- mindspore/ops/_op_impl/tbe/arg_max.py | 4 ++-- mindspore/ops/operations/array_ops.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/ops/_op_impl/tbe/arg_max.py b/mindspore/ops/_op_impl/tbe/arg_max.py index dbfe2ad92..b91df1cfb 100644 --- a/mindspore/ops/_op_impl/tbe/arg_max.py +++ b/mindspore/ops/_op_impl/tbe/arg_max.py @@ -23,8 +23,8 @@ arg_max_op_info = TBERegOp("Argmax") \ .compute_cost(10) \ .kernel_name("arg_max_d") \ .partial_flag(True) \ - .attr("dimension", "required", "int", "all") \ - .attr("dtype", "optional", "type", "all") \ + .attr("axis", "required", "int", "all") \ + .attr("output_dtype", "optional", "type", "all") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.I32_Default) \ diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b27865c52..e8cdbe5e9 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -951,8 +951,8 @@ class Argmax(PrimitiveWithInfer): Args: axis (int): Axis on which Argmax operation applies. Default: -1. - output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32` and - `mindspore.dtype.int64`. Default: `mindspore.dtype.int64`. + output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`. + Default: `mindspore.dtype.int32`. Inputs: - **input_x** (Tensor) - Input tensor. @@ -961,12 +961,12 @@ class Argmax(PrimitiveWithInfer): Tensor, indices of the max value of input tensor across the axis. Examples: - >>> input_x = Tensor(np.array([2.0, 3.1, 1.2])) + >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32) >>> index = P.Argmax(output_type=mindspore.int32)(input_x) """ @prim_attr_register - def __init__(self, axis=-1, output_type=mstype.int64): + def __init__(self, axis=-1, output_type=mstype.int32): """init Argmax""" self.init_prim_io_names(inputs=['x'], outputs=['output']) validator.check_value_type("axis", axis, [int], self.name) -- GitLab