diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 9e69cc7445cda0a634523a040267b6f98edc6558..e758a20b3534716b293a5b3f5499363de68831f1 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -77,6 +77,7 @@ static std::map tbe_func_adapter_map = { {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, {"pad", "pad_d"}, + {"argmax", "arg_max_d"}, {"space_to_batch", "space_to_batch_d"}, {"batch_to_space", "batch_to_space_d"}, {"resize_bilinear", "resize_bilinear_v2_d"}, diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index c8aa30f2c25198ed757191e6825b09931a245d00..aa604d18de6346b14f4f0a5a00056cbdfb15eb12 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -175,6 +175,7 @@ from .bounding_box_decode import _bounding_box_decode_tbe from .bounding_box_encode import _bounding_box_encode_tbe from .check_valid import _check_valid_tbe from .iou import _iou_tbe +from .arg_max import _arg_max_tbe from .nms_with_mask import nms_with_mask_op_info from .random_choice_with_mask import random_choice_with_mask_op_info from .sgd import sgd_op_info diff --git a/mindspore/ops/_op_impl/tbe/arg_max.py b/mindspore/ops/_op_impl/tbe/arg_max.py new file mode 100644 index 0000000000000000000000000000000000000000..dbfe2ad923b7eedb1239980b1949a9d10ecdee61 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/arg_max.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Argmax op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +arg_max_op_info = TBERegOp("Argmax") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("arg_max_d.so") \ + .compute_cost(10) \ + .kernel_name("arg_max_d") \ + .partial_flag(True) \ + .attr("dimension", "required", "int", "all") \ + .attr("dtype", "optional", "type", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(arg_max_op_info) +def _arg_max_tbe(): + """Argmax TBE register""" + return