提交 52e97dbb 编写于 作者: W WilliamLian

using device dtype to create transdata kernel build info

上级 94883f9b
...@@ -506,7 +506,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { ...@@ -506,7 +506,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
if (select_status == kNoMatched) { if (select_status == kNoMatched) {
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info"; << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
kernel::AICpuQuery(kernel_node, &kernel_info_list); kernel::AICPUQuery(kernel_node, &kernel_info_list);
select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
} }
......
...@@ -71,21 +71,20 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel ...@@ -71,21 +71,20 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
FilterInvalidKernelInfo(kernel_node, kernel_info_list); FilterInvalidKernelInfo(kernel_node, kernel_info_list);
} }
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list); MS_EXCEPTION_IF_NULL(kernel_info_list);
kernel_info_list->clear(); kernel_info_list->clear();
AicpuMetadataInfo(kernel_node, kernel_info_list); AicpuMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list); FilterInvalidKernelInfo(kernel_node, kernel_info_list);
} }
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info); MS_EXCEPTION_IF_NULL(select_kernel_build_info);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
auto cnode = kernel_node->cast<CNodePtr>(); auto cnode = kernel_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
AicpuMetadataInfo(cnode, &kernel_info_list); AICPUQuery(cnode, &kernel_info_list);
FilterInvalidKernelInfo(cnode, &kernel_info_list);
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
...@@ -93,7 +92,7 @@ bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr ...@@ -93,7 +92,7 @@ bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
}); });
} }
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info); MS_EXCEPTION_IF_NULL(select_kernel_build_info);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
......
...@@ -26,9 +26,9 @@ ...@@ -26,9 +26,9 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ #endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
...@@ -559,6 +559,9 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for ...@@ -559,6 +559,9 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
if (format == kOpFormat_DEFAULT) { if (format == kOpFormat_DEFAULT) {
return true; return true;
} }
if (format == kOpFormat_NDHWC && shape.size() != kShape5dDims) {
return false;
}
// if shape size is 0, the shape will be a scalar // if shape size is 0, the shape will be a scalar
if (shape.empty()) { if (shape.empty()) {
return true; return true;
...@@ -574,21 +577,28 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for ...@@ -574,21 +577,28 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool { const size_t kCAxis = 1;
if (!IsShapeMatchFormat(shape, format)) {
return false;
}
return true;
};
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) { if (kernel_build_info.GetOutputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
if (output_shape.size() != kShape4dDims || output_shape[kCAxis] > 4) {
return false;
}
return false;
}
if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
return false; return false;
} }
} }
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) { if (!IsShapeMatchFormat(input_shape, kernel_build_info.GetInputFormat(index))) {
return false;
}
if (kernel_build_info.GetInputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
if (input_shape.size() != kShape4dDims || input_shape[kCAxis] > 4) {
return false;
}
return false; return false;
} }
} }
......
...@@ -20,12 +20,12 @@ ...@@ -20,12 +20,12 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h" #include "kernel/kernel_build_info.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list); void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
......
...@@ -32,13 +32,13 @@ namespace opt { ...@@ -32,13 +32,13 @@ namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace { namespace {
kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &node, const AnfNodePtr &node, const TypeId device_type,
const kernel::KernelBuildInfo ori_build_info) { const kernel::KernelBuildInfo &ori_build_info) {
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format}); builder.SetInputsFormat({input_format});
builder.SetOutputsFormat({output_format}); builder.SetOutputsFormat({output_format});
builder.SetInputsDeviceType({ori_build_info.GetInputDeviceType(0)}); builder.SetInputsDeviceType({device_type});
builder.SetOutputsDeviceType({ori_build_info.GetOutputDeviceType(0)}); builder.SetOutputsDeviceType({device_type});
builder.SetKernelType(ori_build_info.kernel_type()); builder.SetKernelType(ori_build_info.kernel_type());
builder.SetFusionType(ori_build_info.fusion_type()); builder.SetFusionType(ori_build_info.fusion_type());
builder.SetProcessor(ori_build_info.processor()); builder.SetProcessor(ori_build_info.processor());
...@@ -56,11 +56,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, ...@@ -56,11 +56,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
CNodePtr trans_node = func_graph->NewCNode(trans_inputs); CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
MS_EXCEPTION_IF_NULL(trans_node); MS_EXCEPTION_IF_NULL(trans_node);
std::vector<kernel::Axis> padding_axis; std::vector<kernel::Axis> padding_axis;
if (AnfAlgo::IsRealKernel(input)) { padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
} else {
padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(input, 0);
}
if (need_padding) { if (need_padding) {
// if need padding we should set the transdata node's shape to the padding shape // if need padding we should set the transdata node's shape to the padding shape
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
...@@ -129,15 +125,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & ...@@ -129,15 +125,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select) { const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
std::string output_format; std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
std::vector<size_t> origin_shape; std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
if (!AnfAlgo::IsRealKernel(node)) {
output_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
} else {
output_format = AnfAlgo::GetOutputFormat(node, 0);
origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
}
if (output_format == kOpFormat_NC1KHKWHWC0) { if (output_format == kOpFormat_NC1KHKWHWC0) {
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
<< node->DebugString(); << node->DebugString();
...@@ -186,6 +175,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt ...@@ -186,6 +175,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
AnfNodePtr trans_node = nullptr; AnfNodePtr trans_node = nullptr;
AnfNodePtr input_node = node; AnfNodePtr input_node = node;
AnfNodePtr trans_data = nullptr; AnfNodePtr trans_data = nullptr;
TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (origin_format.empty() || dest_format.empty()) { if (origin_format.empty() || dest_format.empty()) {
MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format; MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format;
...@@ -196,6 +186,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt ...@@ -196,6 +186,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
} }
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
input_node = AnfAlgo::GetInputNode(cnode, insert_index); input_node = AnfAlgo::GetInputNode(cnode, insert_index);
} }
...@@ -231,7 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt ...@@ -231,7 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_EXCEPTION_IF_NULL(trans_data); MS_EXCEPTION_IF_NULL(trans_data);
MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info); auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get()); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
return trans_node; return trans_node;
} }
......
...@@ -39,11 +39,11 @@ class SupportedChecker { ...@@ -39,11 +39,11 @@ class SupportedChecker {
virtual ~SupportedChecker() = default; virtual ~SupportedChecker() = default;
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) { const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info); return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
} }
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) { const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info); return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
} }
}; };
using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>; using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
......
...@@ -114,8 +114,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { ...@@ -114,8 +114,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0); auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);
auto cast = trans_road[1]; auto cast = trans_road[1];
AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
if (param_format == format && param_dtype != dtype) { if (param_format == format && param_dtype != dtype) {
AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
manager->Replace(trans_road[2], final_node); manager->Replace(trans_road[2], final_node);
manager->Replace(cur_transop, cast); manager->Replace(cur_transop, cast);
} }
......
...@@ -292,6 +292,9 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t ...@@ -292,6 +292,9 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
<< " is out of the node output range :" << GetOutputTensorNum(node) << " #node [" << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
<< node->DebugString() << "]"; << node->DebugString() << "]";
} }
if (!AnfAlgo::IsRealKernel(node)) {
return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
}
auto kernel_info = node->kernel_info(); auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info(); auto build_info = kernel_info->select_kernel_build_info();
...@@ -311,6 +314,9 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i ...@@ -311,6 +314,9 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
<< " is out of the number node Input range :" << GetInputTensorNum(node) << "#node [" << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
<< node->DebugString() << "]"; << node->DebugString() << "]";
} }
if (!IsRealKernel(node)) {
GetPrevNodeOutputFormat(node, input_idx);
}
auto kernel_info = node->kernel_info(); auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info(); auto build_info = kernel_info->select_kernel_build_info();
...@@ -367,8 +373,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n ...@@ -367,8 +373,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
} else if (b_shp->isa<abstract::NoShape>()) { } else if (b_shp->isa<abstract::NoShape>()) {
return std::vector<size_t>(); return std::vector<size_t>();
} else { } else {
MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is " MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
<< base_shape->ToString(); << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
} }
} else if (base_shape->isa<abstract::NoShape>()) { } else if (base_shape->isa<abstract::NoShape>()) {
return std::vector<size_t>(); return std::vector<size_t>();
...@@ -415,6 +421,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode ...@@ -415,6 +421,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
<< " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node[" << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
<< node->DebugString() << "]"; << node->DebugString() << "]";
} }
if (!IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, input_idx);
}
auto kernel_info = node->kernel_info(); auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info(); auto build_info = kernel_info->select_kernel_build_info();
...@@ -431,6 +440,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod ...@@ -431,6 +440,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"; << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
} }
if (!IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, output_idx);
}
auto kernel_info = node->kernel_info(); auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info(); auto build_info = kernel_info->select_kernel_build_info();
...@@ -488,6 +500,9 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size ...@@ -488,6 +500,9 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"; << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
} }
if (!IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, output_idx);
}
auto kernel_info = node->kernel_info(); auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info(); auto build_info = kernel_info->select_kernel_build_info();
...@@ -506,6 +521,9 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_ ...@@ -506,6 +521,9 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ " MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
<< GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"; << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
} }
if (!IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, 0);
}
auto kernel_info = node->kernel_info(); auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info(); auto build_info = kernel_info->select_kernel_build_info();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册