提交 52e97dbb 编写于 作者: W WilliamLian

using device dtype to create transdata kernel build info

上级 94883f9b
......@@ -506,7 +506,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
if (select_status == kNoMatched) {
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
kernel::AICpuQuery(kernel_node, &kernel_info_list);
kernel::AICPUQuery(kernel_node, &kernel_info_list);
select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
}
......
......@@ -71,21 +71,20 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
kernel_info_list->clear();
AicpuMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
auto cnode = kernel_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
AicpuMetadataInfo(cnode, &kernel_info_list);
FilterInvalidKernelInfo(cnode, &kernel_info_list);
AICPUQuery(cnode, &kernel_info_list);
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
......@@ -93,7 +92,7 @@ bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
});
}
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
......
......@@ -26,9 +26,9 @@
namespace mindspore {
namespace kernel {
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
......@@ -559,6 +559,9 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
if (format == kOpFormat_DEFAULT) {
return true;
}
if (format == kOpFormat_NDHWC && shape.size() != kShape5dDims) {
return false;
}
// if shape size is 0, the shape will be a scalar
if (shape.empty()) {
return true;
......@@ -574,21 +577,28 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
if (!IsShapeMatchFormat(shape, format)) {
return false;
}
return true;
};
const size_t kCAxis = 1;
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
if (kernel_build_info.GetOutputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
if (output_shape.size() != kShape4dDims || output_shape[kCAxis] > 4) {
return false;
}
return false;
}
if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
return false;
}
}
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
if (!IsShapeMatchFormat(input_shape, kernel_build_info.GetInputFormat(index))) {
return false;
}
if (kernel_build_info.GetInputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
if (input_shape.size() != kShape4dDims || input_shape[kCAxis] > 4) {
return false;
}
return false;
}
}
......
......@@ -20,12 +20,12 @@
#include <string>
#include <vector>
#include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace kernel {
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info);
} // namespace kernel
} // namespace mindspore
......
......@@ -32,13 +32,13 @@ namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &node,
const kernel::KernelBuildInfo ori_build_info) {
const AnfNodePtr &node, const TypeId device_type,
const kernel::KernelBuildInfo &ori_build_info) {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetOutputsFormat({output_format});
builder.SetInputsDeviceType({ori_build_info.GetInputDeviceType(0)});
builder.SetOutputsDeviceType({ori_build_info.GetOutputDeviceType(0)});
builder.SetInputsDeviceType({device_type});
builder.SetOutputsDeviceType({device_type});
builder.SetKernelType(ori_build_info.kernel_type());
builder.SetFusionType(ori_build_info.fusion_type());
builder.SetProcessor(ori_build_info.processor());
......@@ -56,11 +56,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
MS_EXCEPTION_IF_NULL(trans_node);
std::vector<kernel::Axis> padding_axis;
if (AnfAlgo::IsRealKernel(input)) {
padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
} else {
padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(input, 0);
}
padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
if (need_padding) {
// if need padding we should set the transdata node's shape to the padding shape
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
......@@ -129,15 +125,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(node);
std::string output_format;
std::vector<size_t> origin_shape;
if (!AnfAlgo::IsRealKernel(node)) {
output_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
} else {
output_format = AnfAlgo::GetOutputFormat(node, 0);
origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
}
std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
if (output_format == kOpFormat_NC1KHKWHWC0) {
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
<< node->DebugString();
......@@ -186,6 +175,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
AnfNodePtr trans_node = nullptr;
AnfNodePtr input_node = node;
AnfNodePtr trans_data = nullptr;
TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
MS_EXCEPTION_IF_NULL(node);
if (origin_format.empty() || dest_format.empty()) {
MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format;
......@@ -196,6 +186,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
}
auto cnode = node->cast<CNodePtr>();
dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
MS_EXCEPTION_IF_NULL(cnode);
input_node = AnfAlgo::GetInputNode(cnode, insert_index);
}
......@@ -231,7 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_EXCEPTION_IF_NULL(trans_data);
MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
return trans_node;
}
......
......@@ -39,11 +39,11 @@ class SupportedChecker {
virtual ~SupportedChecker() = default;
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info);
return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
}
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info);
return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
}
};
using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
......
......@@ -114,8 +114,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);
auto cast = trans_road[1];
AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
if (param_format == format && param_dtype != dtype) {
AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
manager->Replace(trans_road[2], final_node);
manager->Replace(cur_transop, cast);
}
......
......@@ -292,6 +292,9 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
<< " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
<< node->DebugString() << "]";
}
if (!AnfAlgo::IsRealKernel(node)) {
return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......@@ -311,6 +314,9 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
<< " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
<< node->DebugString() << "]";
}
if (!IsRealKernel(node)) {
GetPrevNodeOutputFormat(node, input_idx);
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......@@ -367,8 +373,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
} else if (b_shp->isa<abstract::NoShape>()) {
return std::vector<size_t>();
} else {
MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
<< base_shape->ToString();
MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
<< " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
}
} else if (base_shape->isa<abstract::NoShape>()) {
return std::vector<size_t>();
......@@ -415,6 +421,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
<< " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
<< node->DebugString() << "]";
}
if (!IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, input_idx);
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......@@ -431,6 +440,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
}
if (!IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, output_idx);
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......@@ -488,6 +500,9 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
}
if (!IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, output_idx);
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......@@ -506,6 +521,9 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
<< GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
}
if (!IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, 0);
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册