diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 9e69cc7445cda0a634523a040267b6f98edc6558..e758a20b3534716b293a5b3f5499363de68831f1 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -77,6 +77,7 @@ static std::map tbe_func_adapter_map = { {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, {"pad", "pad_d"}, + {"argmax", "arg_max_d"}, {"space_to_batch", "space_to_batch_d"}, {"batch_to_space", "batch_to_space_d"}, {"resize_bilinear", "resize_bilinear_v2_d"}, diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index c8aa30f2c25198ed757191e6825b09931a245d00..aa604d18de6346b14f4f0a5a00056cbdfb15eb12 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -175,6 +175,7 @@ from .bounding_box_decode import _bounding_box_decode_tbe from .bounding_box_encode import _bounding_box_encode_tbe from .check_valid import _check_valid_tbe from .iou import _iou_tbe +from .arg_max import _arg_max_tbe from .nms_with_mask import nms_with_mask_op_info from .random_choice_with_mask import random_choice_with_mask_op_info from .sgd import sgd_op_info diff --git a/mindspore/ops/_op_impl/tbe/arg_max.py b/mindspore/ops/_op_impl/tbe/arg_max.py new file mode 100644 index 0000000000000000000000000000000000000000..b91df1cfb62fdc9b1ba2be0c45a9949735b36a5d --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/arg_max.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Argmax op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +arg_max_op_info = TBERegOp("Argmax") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("arg_max_d.so") \ + .compute_cost(10) \ + .kernel_name("arg_max_d") \ + .partial_flag(True) \ + .attr("axis", "required", "int", "all") \ + .attr("output_dtype", "optional", "type", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(arg_max_op_info) +def _arg_max_tbe(): + """Argmax TBE register""" + return diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b27865c5285b1354aa7e3f71d9ebce5ea4b4b944..e8cdbe5e90f1b08cb1e6c2370c0243e2f1f002b5 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -951,8 +951,8 @@ class Argmax(PrimitiveWithInfer): Args: axis (int): Axis on which Argmax operation applies. Default: -1. - output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32` and - `mindspore.dtype.int64`. Default: `mindspore.dtype.int64`. + output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`. + Default: `mindspore.dtype.int32`. Inputs: - **input_x** (Tensor) - Input tensor. @@ -961,12 +961,12 @@ class Argmax(PrimitiveWithInfer): Tensor, indices of the max value of input tensor across the axis. Examples: - >>> input_x = Tensor(np.array([2.0, 3.1, 1.2])) + >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32) >>> index = P.Argmax(output_type=mindspore.int32)(input_x) """ @prim_attr_register - def __init__(self, axis=-1, output_type=mstype.int64): + def __init__(self, axis=-1, output_type=mstype.int32): """init Argmax""" self.init_prim_io_names(inputs=['x'], outputs=['output']) validator.check_value_type("axis", axis, [int], self.name)