提交 88215d00 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!905 add topk op for aicpu

Merge pull request !905 from yanzhenxiang2020/add_topkop_for_aicpu
......@@ -111,6 +111,9 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
CreateCpuKernelInfo(inputs, outputs);
auto *stream = reinterpret_cast<rtStream_t *>(stream_ptr);
if (node_name_ == "TopK") {
node_name_ = "TopKV2";
}
MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_
<< ", args_size:" << args_.length();
if (rtCpuKernelLaunch(reinterpret_cast<const void *>(node_so_.c_str()),
......@@ -137,6 +140,9 @@ vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr> &inp
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
[](const AddressPtr &output) -> void * { return output->addr; });
if (node_name_ == "TopK") {
node_name_ = "TopKV2";
}
AicpuTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::AicpuTaskInfo>(
stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs);
......
......@@ -568,6 +568,12 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list;
if (AnfAlgo::GetCNodeName(kernel_node) == kTopKOpName && AnfAlgo::GetNodeAttr<bool>(kernel_node, "sorted") == false) {
MS_LOG(INFO) << "will select aicpu topk.";
return;
}
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
if (op_info_ptr == nullptr) {
......
......@@ -17,3 +17,4 @@ from .init_data_set_queue import _init_data_set_queue_aicpu
from .dropout_genmask import _dropout_genmask_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu
from .topk import _top_k_aicpu
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TopK op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
top_k_op_info = AiCPURegOp("TopK") \
.fusion_type("OPAQUE") \
.attr("sorted", "bool")\
.input(0, "intput", "required") \
.input(1, "k", "required") \
.output(0, "values", "required") \
.output(1, "indices", "required") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(top_k_op_info)
def _top_k_aicpu():
"""TopK aicpu register"""
return
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册