提交 44df45c8 编写于 作者: W WilliamLian

add attr op_pattern to kernel build info

上级 074a2f34
......@@ -425,7 +425,7 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
return kernel_info_list[selected_index];
}
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetAllMatchedFilteredKernelInfo(
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
for (const auto &kernel_build_info : kernel_info_list) {
......@@ -474,7 +474,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
// Matched kernel info
// Filter kernel info matched with me inferred type
auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list);
auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list);
if (!filtered_kernel_info_list.empty()) {
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
select_status = kStatusAllMatched;
......@@ -508,6 +508,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
kernel::AICpuQuery(kernel_node, &kernel_info_list);
select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
}
// The kernel info was found in neither the aicpu kernel list nor the aicore kernel list
if (select_status == kNoMatched) {
......
......@@ -47,6 +47,13 @@ enum FusionType {
OPAQUE,
UNKNOWN_FUSION_TYPE = -1,
};
// Pattern category attached to an operator's kernel build info.
// Enumerators are numbered consecutively from 0; new entries must be appended
// at the end so the existing numeric values stay unchanged.
enum OpPattern {
  kCommonPattern = 0,      // 0
  kFormatAgnosticPattern,  // 1
  kBroadcastPattern,       // 2
  kReducePattern,          // 3
  kDynamicFormatPattern,   // 4
};
// Backend processor
enum Processor {
......
......@@ -162,5 +162,10 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->output_reshape_type_ = output_reshape_type;
}
// Stores `pattern` as the op pattern of the KernelBuildInfo held by this builder.
void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) {
// The held build info is null-checked before it is written to.
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->op_pattern_ = pattern;
}
} // namespace kernel
} // namespace mindspore
......@@ -34,6 +34,7 @@ class KernelBuildInfo {
kernel_type_ = AUTO_DIFF_KERNEL;
fusion_type_ = OPAQUE;
processor_ = AICORE;
op_pattern_ = kCommonPattern;
input_reshape_type_ = {};
output_reshape_type_ = {};
inputs_format_ = {};
......@@ -70,6 +71,8 @@ class KernelBuildInfo {
std::vector<TypeId> GetAllOutputDeviceTypes() const;
// Returns the op pattern stored in this build info.
OpPattern op_pattern() const { return op_pattern_; }
// Returns the fusion type stored in this build info.
FusionType fusion_type() const { return fusion_type_; }
// Returns the backend processor stored in this build info.
Processor processor() const { return processor_; }
......@@ -88,6 +91,7 @@ class KernelBuildInfo {
private:
KernelType kernel_type_;
std::vector<std::string> inputs_format_;
OpPattern op_pattern_;
std::vector<std::string> outputs_format_;
std::vector<std::vector<Axis>> input_reshape_type_;
std::vector<std::vector<Axis>> output_reshape_type_;
......@@ -125,6 +129,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void SetProcessor(Processor processor);
void SetOpPattern(OpPattern pattern);
std::shared_ptr<KernelBuildInfo> Build();
private:
......
......@@ -40,7 +40,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
(void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
} else {
MS_LOG(WARNING) << "All kernel Info list does not match any kernel info ";
for (size_t index; index < kernel_info_list->size(); ++index) {
for (size_t index = 0; index < kernel_info_list->size(); ++index) {
MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString();
}
......
......@@ -21,6 +21,7 @@
#include <memory>
#include <unordered_map>
#include "ir/dtype.h"
#include "kernel/kernel.h"
namespace mindspore {
namespace kernel {
......@@ -100,7 +101,7 @@ class OpInfo {
std::string kernel_name() const { return kernel_name_; }
bool partial_flag() const { return partial_flag_; }
bool dynamic_format() const { return dynamic_format_; }
std::string op_pattern() const { return op_pattern_; }
OpPattern op_pattern() const { return op_pattern_; }
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
......@@ -116,7 +117,7 @@ class OpInfo {
void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; }
void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
......@@ -137,7 +138,7 @@ class OpInfo {
std::string kernel_name_;
bool partial_flag_ = false;
bool dynamic_format_ = false;
std::string op_pattern_;
OpPattern op_pattern_ = kCommonPattern;
std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_;
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_;
......
......@@ -18,6 +18,7 @@
#include <pybind11/pybind11.h>
#include <unordered_map>
#include <memory>
#include <map>
#include "utils/log_adapter.h"
#include "utils/overload.h"
#include "utils/context/ms_context.h"
......@@ -35,6 +36,9 @@ constexpr auto kPartialFlag = "partial_flag";
constexpr auto kReshapeType = "reshape_type";
constexpr auto kOpPattern = "op_pattern";
constexpr auto kDynamicFormat = "dynamic_format";
constexpr auto kFormatAgnostic = "formatAgnostic";
constexpr auto kBroadcast = "broadcast";
constexpr auto kReduce = "reduce";
constexpr auto kDtypeFormat = "dtype_format";
constexpr auto kAttr = "attr";
constexpr auto kIputs = "inputs";
......@@ -95,13 +99,19 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
}
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
{kFormatAgnostic, kBroadcastPattern},
{kReduce, kReducePattern},
{kDynamicFormat, kDynamicFormatPattern}};
op_info->set_async_flag(obj.at(kAsyncFlag));
op_info->set_binfile_name(obj.at(kBinfileName));
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kOpPattern) != obj.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
}
}
if (obj.find(kDynamicFormat) != obj.end()) {
op_info->set_dynamic_format(obj.at(kDynamicFormat));
......
......@@ -492,6 +492,7 @@ void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuild
if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) {
builder->SetFusionType(tbe::GetFusionType(fusion_type));
}
builder->SetOpPattern(op_info_ptr->op_pattern());
builder->SetKernelType(TBE_KERNEL);
}
......@@ -509,7 +510,7 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
}
if (inputs.size() > 0) {
if (!inputs.empty()) {
MS_EXCEPTION_IF_NULL(inputs[0]);
size_t kernel_info_cnt = inputs[0]->dtypes().size();
for (size_t j = 0; j < kernel_info_cnt; j++) {
......@@ -624,21 +625,17 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
for (auto parse_info : parse_info_list) {
if (context_ptr->execution_mode() == kPynativeMode) {
kernel_info_list->push_back(parse_info);
} else {
if (IsValidKernelInfo(kernel_node, *(parse_info))) {
if (CheckSupported(kernel_node, parse_info)) {
kernel_info_list->push_back(parse_info);
} else {
MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info.";
}
for (const auto &parse_info : parse_info_list) {
if (IsValidKernelInfo(kernel_node, *(parse_info))) {
if (CheckSupported(kernel_node, parse_info)) {
kernel_info_list->push_back(parse_info);
} else {
MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info.";
}
}
}
if (kernel_info_list->empty()) {
MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "].";
if (kernel_info_list->empty()) {
MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "].";
}
}
}
} // namespace kernel
......
......@@ -44,6 +44,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
builder->SetKernelType(AICPU_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), node);
} else {
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
<< node->DebugString() << "]";
......
......@@ -657,6 +657,16 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_
to_node->set_abstract(from_node->abstract());
}
// Returns the op pattern recorded in the kernel build info selected for `node`.
// The node, its kernel info and the selected build info are each null-checked
// with MS_EXCEPTION_IF_NULL before being dereferenced.
kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto node_kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(node_kernel_info);
  auto selected_build_info = node_kernel_info->select_kernel_build_info();
  // Checked explicitly here before the pattern is read from it.
  MS_EXCEPTION_IF_NULL(selected_build_info);
  return selected_build_info->op_pattern();
}
// get KernelBuildType of node, such as ATT,RT,FWK and so on
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
......
......@@ -138,6 +138,8 @@ class AnfRuntimeAlgorithm {
static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
// get op pattern of the node
static kernel::OpPattern GetOpPattern(const AnfNodePtr &node);
// get KernelBuildType of node ,such as ATT,RT,FWK and so on
static KernelType GetKernelType(const AnfNodePtr &node);
// get processor type:AICORE,AICPU...
......
......@@ -142,6 +142,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto";
// attr key name
constexpr auto kAttrInputNames = "input_names";
constexpr auto kAttrIsAICPUKernel = "is_ai_cpu_kernel";
constexpr auto kIsBackendCast = "is_backed_cast";
constexpr auto kAttrOutputNames = "output_names";
constexpr auto kAttrVisited = "visited";
......@@ -215,10 +216,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
constexpr auto kOpFormat_NDHWC = "NDHWC";
const std::set<std::string> kOpFormatList = {
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC};
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
const std::set<std::string> kOptOperatorSet = {
kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册