Commit 2f5cbfc2 authored by zhoufeng

Optimize graph compile performance

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
Parent a1b517b0
@@ -31,12 +31,16 @@ namespace {
 void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
                              std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
   MS_EXCEPTION_IF_NULL(kernel_info_list);
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
-  (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
-                     [&kernel_node](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
-                       return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
-                              AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
-                     });
+  (void)std::copy_if(
+    kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
+    [output_tensor_num, input_tensor_num](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
+      return kernel_build_info->GetOutputNum() == output_tensor_num &&
+             kernel_build_info->GetInputNum() == input_tensor_num;
+    });
   if (!filtered_list.empty()) {
     kernel_info_list->clear();
     (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
@@ -44,21 +48,20 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
     MS_LOG(INFO) << "All kernel info in the list does not match the kernel node.";
     for (size_t index = 0; index < kernel_info_list->size(); ++index) {
       std::ostringstream buffer;
-      auto kernel_info = kernel_info_list->at(index);
+      auto &kernel_info = kernel_info_list->at(index);
       MS_EXCEPTION_IF_NULL(kernel_info);
-      if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info->GetOutputNum()) {
-        buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
+      if (kernel_info->GetOutputNum() != output_tensor_num) {
+        buffer << "Kernel node's output size [" << output_tensor_num << "]"
                << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
       } else {
-        buffer << "Kernel node's output size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]"
-               << " cannot match the kernel's output size [" << kernel_info->GetInputNum() << "]";
+        buffer << "Kernel node's input size [" << input_tensor_num << "]"
+               << " cannot match the kernel's input size [" << kernel_info->GetInputNum() << "]";
       }
       MS_LOG(INFO) << "kernel [ " << index << " ] : " << kernel_info->ToString() << buffer.str();
     }
     kernel_info_list->clear();
-    MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : ["
-                 << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
-                 << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
+    MS_LOG(INFO) << "Node " << kernel_node->DebugString() << "'s output size [" << output_tensor_num << "]"
+                 << " and input size [" << input_tensor_num << "] cannot match any kernel info!";
   }
 }
 }  // namespace
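The heart of this hunk is hoisting loop-invariant work out of the filter predicate: the old lambda re-ran `AnfAlgo::GetOutputTensorNum` and `AnfAlgo::GetInputTensorNum` for every candidate build info, even though both depend only on `kernel_node`. A minimal standalone sketch of the same pattern, using illustrative types rather than MindSpore APIs:

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical stand-ins for the per-node queries; assume they are
// comparatively expensive (they walk the node's inputs/outputs).
struct Node {
  std::vector<int> inputs;
  std::vector<int> outputs;
};
std::size_t GetInputNum(const Node &n) { return n.inputs.size(); }
std::size_t GetOutputNum(const Node &n) { return n.outputs.size(); }

struct BuildInfo {
  std::size_t input_num;
  std::size_t output_num;
};

std::vector<BuildInfo> FilterMatching(const Node &node, const std::vector<BuildInfo> &candidates) {
  // Run the queries once, before the scan, instead of inside the predicate.
  const std::size_t input_num = GetInputNum(node);
  const std::size_t output_num = GetOutputNum(node);
  std::vector<BuildInfo> filtered;
  (void)std::copy_if(candidates.begin(), candidates.end(), std::back_inserter(filtered),
                     [input_num, output_num](const BuildInfo &info) {
                       // The lambda now captures two integers by value;
                       // nothing is recomputed per candidate.
                       return info.input_num == input_num && info.output_num == output_num;
                     });
  return filtered;
}
```

Capturing two integers by value also keeps the predicate trivially cheap to copy, which matters because `std::copy_if` takes it by value.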
......
@@ -60,7 +60,7 @@ constexpr auto kFormat = "format";
 constexpr auto kNeedCompile = "need_compile";
 constexpr auto kShape = "shape";
 constexpr auto kProcessor = "processor";
-std::vector<std::shared_ptr<OpInfo>> OpLib::op_info_;
+std::multimap<std::string, std::shared_ptr<OpInfo>> OpLib::op_info_;

 static std::string ImplTypeToStr(OpImplyType impl_type) {
   switch (impl_type) {
@@ -133,11 +133,11 @@ void OpLib::DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_p
 }

 bool OpLib::RegOpFromLocalInfo() {
-  MS_LOG(INFO) << "Start";
   static bool has_load = false;
   if (has_load) {
     return true;
   }
+  MS_LOG(INFO) << "Start";
   has_load = true;
   std::string dir = common::GetEnv("MINDSPORE_OP_INFO_PATH");
   if (dir.empty()) {
@@ -224,7 +224,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI
     MS_LOG(ERROR) << "GetRefInfo Failed";
     return false;
   }
-  op_info_.push_back(op_info);
+  op_info_.emplace(op_info->op_name(), op_info);
   return true;
 }
@@ -337,13 +337,16 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
     return nullptr;
   }
   std::string target_processor = is_gpu ? kCUDA : kAiCore;
-  for (const auto &op_info : op_info_) {
+  for (auto [iter, end] = op_info_.equal_range(op_name); iter != end; ++iter) {
+    auto &op_info = iter->second;
     MS_EXCEPTION_IF_NULL(op_info);
-    if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) {
-      if (imply_type != kAKG || op_info->processor() == target_processor) {
-        return op_info;
-      }
+    if (op_info->imply_type() != imply_type) {
+      continue;
     }
+    if (imply_type == kAKG && op_info->processor() != target_processor) {
+      continue;
+    }
+    return op_info;
   }
   MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
                << ", current op num: " << op_info_.size();
@@ -376,7 +379,8 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) {
 bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
   MS_EXCEPTION_IF_NULL(op_info);
-  for (const auto &exist_op_info : op_info_) {
+  for (auto [iter, end] = op_info_.equal_range(op_info->op_name()); iter != end; ++iter) {
+    auto &exist_op_info = iter->second;
     MS_EXCEPTION_IF_NULL(exist_op_info);
     if (exist_op_info->equals_to(op_info)) {
       return true;
......
@@ -19,6 +19,7 @@
 #include <vector>
 #include <string>
 #include <memory>
+#include <map>
 #include <nlohmann/json.hpp>
 #include "utils/ms_utils.h"
 #include "backend/kernel_compiler/oplib/opinfo.h"
@@ -30,12 +31,12 @@ class OpLib {
   OpLib() = default;
   virtual ~OpLib() = default;
   static bool RegOp(const std::string &json_string, const std::string &impl_path);
-  static void RegOpInfo(const std::shared_ptr<OpInfo> &opinfo) { op_info_.emplace_back(opinfo); }
+  static void RegOpInfo(const std::shared_ptr<OpInfo> &opinfo) { op_info_.emplace(opinfo->op_name(), opinfo); }
   static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type);
-  static const std::vector<std::shared_ptr<OpInfo>> &GetAllOpsInfo() { return op_info_; }
+  static const std::multimap<std::string, std::shared_ptr<OpInfo>> &GetAllOpsInfo() { return op_info_; }

  protected:
-  static std::vector<std::shared_ptr<OpInfo>> op_info_;
+  static std::multimap<std::string, std::shared_ptr<OpInfo>> op_info_;

  private:
   static bool RegOpFromLocalInfo();
......
@@ -32,7 +32,7 @@ class OpInfoLoaderPy {
     auto ops = OpLib::GetAllOpsInfo();
     auto op_infos = new std::vector<OpInfo *>();
     for (auto op_info : ops) {
-      auto new_op_info = new OpInfo(*op_info);
+      auto new_op_info = new OpInfo(*op_info.second);
       op_infos->emplace_back(new_op_info);
     }
     return (size_t)op_infos;
......
@@ -71,8 +71,7 @@ static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>
   memo->insert(graph.get());
   MS_LOG(INFO) << "Assign label for " << graph->ToString();
   graph->SetExecOrderByDefault();
-  auto nodes = graph->execution_order();
+  const auto &nodes = graph->execution_order();
   for (auto &node : nodes) {
     if (!node->isa<CNode>()) {
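Binding the result to `const auto &` avoids copying the whole execution-order vector on every visit; this only helps because `execution_order()` evidently returns a reference, as the change implies. In `AssignLabelForGotoSwitch` (next hunk) the copy previously existed so `end_goto` could be appended to it, and dropping that append is what permits the reference binding there. A self-contained illustration with hypothetical types:

```cpp
#include <vector>

// Hypothetical graph type; the accessor returning a reference is the
// assumption that makes the const-ref binding below skip the copy.
class Graph {
 public:
  const std::vector<int> &execution_order() const { return order_; }

 private:
  std::vector<int> order_{1, 2, 3};
};

void Visit(const Graph &graph) {
  // Old: auto nodes = graph.execution_order();  // deep-copies the vector
  const auto &nodes = graph.execution_order();   // read-only view, no copy
  for (const auto &node : nodes) {
    (void)node;  // process node
  }
}
```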
@@ -103,11 +102,7 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap
   MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString();
-  auto nodes = graph->execution_order();
-  auto end_goto = graph->get_end_goto();
-  if (end_goto != nullptr) {
-    nodes.push_back(end_goto);
-  }
+  const auto &nodes = graph->execution_order();
   for (auto &node : nodes) {
     if (!node->isa<CNode>()) {
       continue;
@@ -115,20 +110,18 @@
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
-    std::string node_name = AnfAlgo::GetCNodeName(node);
-    if (node_name == kLabelGotoOpName) {
+    if (IsPrimitiveCNode(cnode, prim::kPrimLabelGoto)) {
       UpdateLabelGoto(NOT_NULL(cnode));
       cnode->set_abstract(nullptr);
     }
-    if (node_name == kLabelSwitchOpName) {
+    if (IsPrimitiveCNode(cnode, prim::kPrimLabelSwitch)) {
       UpdateLabelSwitch(NOT_NULL(cnode));
     }
   }
   for (auto &cg : graph->child_graph_order()) {
     AssignLabelForGotoSwitch(NOT_NULL(cg), memo);
   }
-  graph->SetExecOrderByDefault();
 }

 void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
......
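Two smaller costs disappear in this pass: `AnfAlgo::GetCNodeName` built a `std::string` for every visited node just to compare it against a name constant, while `IsPrimitiveCNode` checks the node's primitive directly; and the trailing `graph->SetExecOrderByDefault()` re-sort is dropped, presumably because nothing in this pass reorders the graph. A rough analogy of the name-string versus primitive-identity check, with invented types:

```cpp
#include <memory>
#include <string>

// Invented types for illustration only.
struct Primitive {
  explicit Primitive(std::string n) : name(std::move(n)) {}
  std::string name;
};
using PrimitivePtr = std::shared_ptr<Primitive>;

struct Node {
  PrimitivePtr prim;
};

// Shared singleton, analogous to prim::kPrimLabelGoto.
const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");

// Old style: materialize a string per node, then compare contents.
std::string GetNodeName(const Node &node) { return node.prim ? node.prim->name : std::string(); }

// New style: identity check first, falling back to a name compare; the
// common case constructs no temporary std::string at all.
bool IsPrimitiveNode(const Node &node, const PrimitivePtr &prim) {
  if (node.prim == prim) {
    return true;  // same interned primitive object
  }
  return node.prim != nullptr && node.prim->name == prim->name;
}
```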
@@ -359,12 +359,23 @@ static inline uint64_t GetCurrentUSec() {
   static uint64_t total_##stage = 0; \
   static uint64_t count_##stage = 0;

 #define PROF_LOCAL_DEFINE(stage) \
   uint64_t total_##stage = 0;    \
   uint64_t count_##stage = 0;

 #define PROF_MULTI_START(stage) uint64_t start_usec_##stage = mindspore::GetCurrentUSec()

-#define PROF_MULTI_END(stage) \
-  ++count_##stage; \
-  uint64_t end_usec_##stage = mindspore::GetCurrentUSec(); \
-  total_##stage += (end_usec_##stage - start_usec_##stage)
+#define PROF_MULTI_END(stage)                                 \
+  do {                                                        \
+    ++count_##stage;                                          \
+    uint64_t end_usec_##stage = mindspore::GetCurrentUSec();  \
+    total_##stage += (end_usec_##stage - start_usec_##stage); \
+  } while (0)
+
+#define PROF_MULTI_PRINT(stage)                                                                              \
+  do {                                                                                                       \
+    MS_LOG(INFO) << #stage << " called " << count_##stage << " times, costs " << total_##stage << " usec."; \
+  } while (0)
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_UTILS_UTILS_H_
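Wrapping the multi-statement macro bodies in `do { ... } while (0)` is the standard hygiene idiom: the expansion becomes a single statement that composes safely with unbraced `if`/`else` and still requires a trailing semicolon, and the block also scopes locals such as `end_usec_##stage`, so `PROF_MULTI_END` can be invoked twice for the same stage in one block without a redefinition error. A self-contained demonstration of the failure mode the wrapper prevents:

```cpp
#include <cstdio>

// A bare multi-statement macro: dangerous as an if-body.
#define BAD_LOG_TWICE(msg)  \
  std::printf("%s\n", msg); \
  std::printf("%s\n", msg)

// The do/while(0) form expands to exactly one statement.
#define GOOD_LOG_TWICE(msg)   \
  do {                        \
    std::printf("%s\n", msg); \
    std::printf("%s\n", msg); \
  } while (0)

int main() {
  bool verbose = false;
  // With BAD_LOG_TWICE here, only the first printf would be guarded by the
  // if; the second would always run, and adding an else would not compile.
  if (verbose) GOOD_LOG_TWICE("hello");
  return 0;
}
```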