From 52e97dbb79a19a30479b11080ea919af847e8646 Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Fri, 15 May 2020 12:58:45 +0800 Subject: [PATCH] using device dtype to create transdata kernel build info --- .../device/ascend/kernel_select_ascend.cc | 2 +- mindspore/ccsrc/kernel/kernel_query.cc | 9 +++--- mindspore/ccsrc/kernel/kernel_query.h | 6 ++-- .../ccsrc/kernel/tbe/tbe_kernel_select.cc | 26 ++++++++++++----- .../ccsrc/kernel/tbe/tbe_kernel_select.h | 2 +- .../pre_activate/ascend/ascend_helper.cc | 29 +++++++------------ .../ccsrc/pre_activate/ascend/ascend_helper.h | 4 +-- .../ir_fusion/parameter_and_transop_fusion.cc | 2 +- .../ccsrc/session/anf_runtime_algorithm.cc | 22 ++++++++++++-- 9 files changed, 60 insertions(+), 42 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 15401a3f2..655e1dcac 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -506,7 +506,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { if (select_status == kNoMatched) { MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; - kernel::AICpuQuery(kernel_node, &kernel_info_list); + kernel::AICPUQuery(kernel_node, &kernel_info_list); select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); } diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index c1bfafc38..0f53a5a2c 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -71,21 +71,20 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { +void AICPUQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_info_list); kernel_info_list->clear(); AicpuMetadataInfo(kernel_node, kernel_info_list); FilterInvalidKernelInfo(kernel_node, kernel_info_list); } -bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { +bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(select_kernel_build_info); std::vector> kernel_info_list; auto cnode = kernel_node->cast(); MS_EXCEPTION_IF_NULL(cnode); - AicpuMetadataInfo(cnode, &kernel_info_list); - FilterInvalidKernelInfo(cnode, &kernel_info_list); + AICPUQuery(cnode, &kernel_info_list); return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { MS_EXCEPTION_IF_NULL(item); @@ -93,7 +92,7 @@ bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr }); } -bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { +bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(select_kernel_build_info); std::vector> kernel_info_list; diff --git a/mindspore/ccsrc/kernel/kernel_query.h b/mindspore/ccsrc/kernel/kernel_query.h index 52ab01898..fe8696a91 100644 --- a/mindspore/ccsrc/kernel/kernel_query.h +++ b/mindspore/ccsrc/kernel/kernel_query.h @@ -26,9 +26,9 @@ namespace mindspore { namespace kernel { void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -void AICpuQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); -bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); +void AICPUQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); +bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index 8b1a1548b..754425df1 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -559,6 +559,9 @@ bool IsShapeMatchFormat(const std::vector &shape, const std::string &for if (format == kOpFormat_DEFAULT) { return true; } + if (format == kOpFormat_NDHWC && shape.size() != kShape5dDims) { + return false; + } // if shape size is 0, the shape will be a scalar if (shape.empty()) { return true; @@ -574,21 +577,28 @@ bool IsShapeMatchFormat(const std::vector &shape, const std::string &for bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { MS_EXCEPTION_IF_NULL(kernel_node); - auto check_function = [](const std::vector &shape, const std::string &format) -> bool { - if (!IsShapeMatchFormat(shape, format)) { - return false; - } - return true; - }; + const size_t kCAxis = 1; for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); - if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) { + if (kernel_build_info.GetOutputFormat(index) == kOpFormat_FRACTAL_Z_C04) { + if (output_shape.size() != kShape4dDims || output_shape[kCAxis] > 4) { + return false; + } + return false; + } + if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) { return false; } } for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); - if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) { + if (!IsShapeMatchFormat(input_shape, kernel_build_info.GetInputFormat(index))) { + return false; + } + if (kernel_build_info.GetInputFormat(index) == kOpFormat_FRACTAL_Z_C04) { + if (input_shape.size() != kShape4dDims || input_shape[kCAxis] > 4) { + return false; + } return false; } } diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h index 4c85468f1..3ce66b514 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h @@ -20,12 +20,12 @@ #include #include #include +#include "kernel/oplib/opinfo.h" #include "kernel/kernel_build_info.h" namespace mindspore { namespace kernel { void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index 7e503ef34..05d36fd4f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -32,13 +32,13 @@ namespace opt { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; namespace { kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &node, - const kernel::KernelBuildInfo ori_build_info) { + const AnfNodePtr &node, const TypeId device_type, + const kernel::KernelBuildInfo &ori_build_info) { KernelBuildInfoBuilder builder; builder.SetInputsFormat({input_format}); builder.SetOutputsFormat({output_format}); - builder.SetInputsDeviceType({ori_build_info.GetInputDeviceType(0)}); - builder.SetOutputsDeviceType({ori_build_info.GetOutputDeviceType(0)}); + builder.SetInputsDeviceType({device_type}); + builder.SetOutputsDeviceType({device_type}); builder.SetKernelType(ori_build_info.kernel_type()); builder.SetFusionType(ori_build_info.fusion_type()); builder.SetProcessor(ori_build_info.processor()); @@ -56,11 +56,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, CNodePtr trans_node = func_graph->NewCNode(trans_inputs); MS_EXCEPTION_IF_NULL(trans_node); std::vector padding_axis; - if (AnfAlgo::IsRealKernel(input)) { - padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); - } else { - padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(input, 0); - } + padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); if (need_padding) { // if need padding we should set the transdata node's shape to the padding shape AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, @@ -129,15 +125,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(node); - std::string output_format; - std::vector origin_shape; - if (!AnfAlgo::IsRealKernel(node)) { - output_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0); - origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); - } else { - output_format = AnfAlgo::GetOutputFormat(node, 0); - origin_shape = AnfAlgo::GetOutputInferShape(node, 0); - } + std::string output_format = AnfAlgo::GetOutputFormat(node, 0); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, 0); if (output_format == kOpFormat_NC1KHKWHWC0) { MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " << node->DebugString(); @@ -186,6 +175,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt AnfNodePtr trans_node = nullptr; AnfNodePtr input_node = node; AnfNodePtr trans_data = nullptr; + TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); MS_EXCEPTION_IF_NULL(node); if (origin_format.empty() || dest_format.empty()) { MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format; @@ -196,6 +186,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; } auto cnode = node->cast(); + dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); MS_EXCEPTION_IF_NULL(cnode); input_node = AnfAlgo::GetInputNode(cnode, insert_index); } @@ -231,7 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt MS_EXCEPTION_IF_NULL(trans_data); MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); - auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info); + auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get()); return trans_node; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index 2f270b109..a5463131b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -39,11 +39,11 @@ class SupportedChecker { virtual ~SupportedChecker() = default; virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { - return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info); + return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); } virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { - return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info); + return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); } }; using SupportedCheckerPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc index fe9b35a5e..a3c87dad5 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc @@ -114,8 +114,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0); auto cast = trans_road[1]; - AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get()); if (param_format == format && param_dtype != dtype) { + AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get()); manager->Replace(trans_road[2], final_node); manager->Replace(cur_transop, cast); } diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 8b45ad7d3..7260bb46d 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -292,6 +292,9 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t << " is out of the node output range :" << GetOutputTensorNum(node) << " #node [" << node->DebugString() << "]"; } + if (!AnfAlgo::IsRealKernel(node)) { + return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx); + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); @@ -311,6 +314,9 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node [" << node->DebugString() << "]"; } + if (!IsRealKernel(node)) { + GetPrevNodeOutputFormat(node, input_idx); + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); @@ -367,8 +373,8 @@ std::vector AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n } else if (b_shp->isa()) { return std::vector(); } else { - MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is " - << base_shape->ToString(); + MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx + << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString(); } } else if (base_shape->isa()) { return std::vector(); @@ -415,6 +421,9 @@ std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node[" << node->DebugString() << "]"; } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputReshapeType(node, input_idx); + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); @@ -431,6 +440,9 @@ std::vector AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"; } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputReshapeType(node, output_idx); + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); @@ -488,6 +500,9 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"; } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputDeviceDataType(node, output_idx); + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); @@ -506,6 +521,9 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_ MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ " << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"; } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputDeviceDataType(node, 0); + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); -- GitLab