diff --git a/akg b/akg index c460176523d039c8995f1d71089753725ebc0792..df57a6cf9450e347d1854687d1fe66a420ee3b35 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit c460176523d039c8995f1d71089753725ebc0792 +Subproject commit df57a6cf9450e347d1854687d1fe66a420ee3b35 diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index 6538c28765c298ae4386a1735c39d096b101a18d..16aa8de3c28d1642c091f44e1950a70471facc3c 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -23,6 +23,7 @@ #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" #include "kernel/akg/akg_kernel_metadata.h" #include "session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" namespace mindspore { namespace kernel { @@ -96,6 +97,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vectorenable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { + kernel_type = KernelType::AKG_KERNEL; + } + switch (kernel_type) { case KernelType::AKG_KERNEL: AkgMetadataInfo(kernel_node, kernel_info_list); diff --git a/mindspore/ops/_op_impl/akg/__init__.py b/mindspore/ops/_op_impl/akg/__init__.py index f38b99f5e4f02f75bcff0c0a147761e77a013383..fd86dbf999160ecbced337b9f2427caaece7fd28 100644 --- a/mindspore/ops/_op_impl/akg/__init__.py +++ b/mindspore/ops/_op_impl/akg/__init__.py @@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg from .less import _less_akg from .log import _log_akg from .matmul import _matmul_akg +from .batchmatmul import _batchmatmul_akg from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg from .max_pool_with_argmax import _max_pool_with_argmax_akg from .max import _max_akg diff --git a/mindspore/ops/_op_impl/akg/batchmatmul.py b/mindspore/ops/_op_impl/akg/batchmatmul.py new file mode 100644 index 0000000000000000000000000000000000000000..f5da71aa25e7634bb6515ef7516d73d21daf371e --- /dev/null +++ b/mindspore/ops/_op_impl/akg/batchmatmul.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BatchMatMul op""" +from mindspore.ops.op_info_register import op_info_register + + +@op_info_register("""{ + "op_name": "BatchMatMul", + "imply_type": "AutoDiff", + "fusion_type": "OPAQUE", + "attr": [ + { + "name": "transpose_a", + "param_type": "optional", + "type": "bool" + }, + { + "name": "transpose_b", + "param_type": "optional", + "type": "bool" + } + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "x1" + }, + { + "index": 1, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "x2" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "output" + } + ] +}""") +def _batchmatmul_akg(): + """BatchMatMul AKG register""" + return