Commit b6ff2724 authored by mindspore-ci-bot, committed by Gitee

!1990 fix profiling stream id

Merge pull request !1990 from caifubi/fix-profiling-stream-id
graphengine @ c54db434
-Subproject commit 45ca7863ac6410c8e2f83168481ddc6b43bcea33
+Subproject commit c54db4343f83cb0c15cc3b5c9755926de27fa3af
@@ -367,7 +367,8 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
   }
   if (ProfilingManager::GetInstance().IsProfiling()) {
     auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);
-    ProfilingUtils::ReportProfilingData(task_ids, NOT_NULL(graph));
+    auto stream_ids = ge::model_runner::ModelRunner::Instance().GetStreamIdList(model_iter->first);
+    ProfilingUtils::ReportProfilingData(task_ids, stream_ids, NOT_NULL(graph));
   }
   return true;
 }
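Note: the two lists fetched here are used as index-aligned sequences, one (task id, stream id) pair per launched task. A minimal sketch of that invariant (illustrative only; the helper below is hypothetical and not part of this change):

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: GetTaskIdList and GetStreamIdList are assumed to return
// index-aligned lists; a size mismatch would corrupt the reported pairs.
void CheckIdListsAligned(const std::vector<uint32_t> &task_ids, const std::vector<uint32_t> &stream_ids) {
  if (task_ids.size() != stream_ids.size()) {
    throw std::runtime_error("task_ids and stream_ids must have the same length");
  }
}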
@@ -302,7 +302,7 @@ bool ProfilingUtils::ValidComputeGraph(NotNull<const session::KernelGraph *> graph) {
   return false;
 }
 
-void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
+void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids, const std::vector<uint32_t> &stream_ids,
                                          NotNull<const session::KernelGraph *> graph) {
   if (!ValidComputeGraph(graph)) {
     MS_LOG(WARNING) << "Not a valid compute graph:" << graph->graph_id();
@@ -319,6 +319,7 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
   MS_EXCEPTION_IF_NULL(context);
   TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second);
   task_reporter.set_task_ids(task_ids);
+  task_reporter.set_stream_ids(stream_ids);
   task_reporter.ReportData();
   GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second);
@@ -87,7 +87,8 @@ class ProfilingUtils {
   // Maps each task_id to a kernel name so the device-side time cost can be attributed to a specific kernel.
   // The device measures the time cost of each task, identified by its task id,
   // but profiling needs (kernel name, time cost) pairs.
-  static void ReportProfilingData(const std::vector<uint32_t> &task_ids, NotNull<const session::KernelGraph *> graph);
+  static void ReportProfilingData(const std::vector<uint32_t> &task_ids, const std::vector<uint32_t> &stream_ids,
+                                  NotNull<const session::KernelGraph *> graph);
   // Get profiling trace points from environment variables.
   // export PROFILING_FP_START='full name of the first cnode to execute'
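The comment above describes the attribution problem this API solves: the device reports time cost per task id, while profiling wants it per kernel name. A hedged sketch of that join, using hypothetical container names rather than the actual MindSpore data structures:

#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Illustrative only: join the device-side (task_id -> time_cost) data with the
// (task_id -> kernel_name) mapping to obtain the (kernel name, time cost)
// pairs the profiler needs. Names here are assumptions, not MindSpore internals.
std::vector<std::pair<std::string, double>> AttributeTimeCost(
    const std::map<uint32_t, std::string> &task_id_to_kernel,
    const std::map<uint32_t, double> &task_id_to_cost) {
  std::vector<std::pair<std::string, double>> result;
  for (const auto &cost : task_id_to_cost) {
    auto it = task_id_to_kernel.find(cost.first);
    if (it != task_id_to_kernel.end()) {
      result.emplace_back(it->second, cost.second);
    }
  }
  return result;
}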
@@ -40,12 +40,22 @@ void TaskDescReporter::ReportData() {
     auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
     MS_EXCEPTION_IF_NULL(node);
     MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
-    auto desc_ptr = std::make_shared<TaskDesc>(node->fullname_with_scope(), task_ids_[task_index++],
-                                               ascend_kernel_mod->block_dim(), ascend_kernel_mod->stream_id());
+    // Check task_id and stream_id valid
+    CheckStreamTaskValid(task_index, task_index);
+    auto desc_ptr = std::make_shared<TaskDesc>(node->fullname_with_scope(), task_ids_[task_index],
+                                               ascend_kernel_mod->block_dim(), stream_ids_[task_index]);
     prof_desc_.emplace_back(desc_ptr);
+    ++task_index;
   }
   DescReporter::ReportData();
 }
+
+void TaskDescReporter::CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id) {
+  if (task_id >= task_ids_.size() || stream_id >= stream_ids_.size()) {
+    MS_LOG(EXCEPTION) << "Index invalid. task_id:" << task_id << ", task_ids.size:" << task_ids_.size()
+                      << ", stream_id:" << stream_id << ", stream_ids.size:" << stream_ids_.size();
+  }
+}
 } // namespace ascend
 } // namespace device
 } // namespace mindspore
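The new CheckStreamTaskValid call guards the task_ids_[task_index] and stream_ids_[task_index] reads above. A standalone sketch of the same bounds-check pattern, illustrative only and not the project's code:

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch of the pattern: validate the shared index against both vectors before
// dereferencing, so a too-short stream id list fails loudly instead of reading
// out of bounds.
void CheckIndexValid(std::size_t index, const std::vector<uint32_t> &task_ids,
                     const std::vector<uint32_t> &stream_ids) {
  if (index >= task_ids.size() || index >= stream_ids.size()) {
    throw std::out_of_range("index out of range for task_ids or stream_ids");
  }
}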
@@ -32,9 +32,12 @@ class TaskDescReporter : public DescReporter {
   ~TaskDescReporter() override = default;
   void ReportData() override;
   void set_task_ids(const std::vector<uint32_t> &task_ids) { task_ids_ = task_ids; }
+  void set_stream_ids(const std::vector<uint32_t> &stream_ids) { stream_ids_ = stream_ids; }
 
  private:
   std::vector<uint32_t> task_ids_;
+  std::vector<uint32_t> stream_ids_;
+  void CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id);
 };
 } // namespace ascend
 } // namespace device
@@ -40,6 +40,11 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
   static std::vector<uint32_t> task_id_list;
   return task_id_list;
 }
+
+const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
+  static std::vector<uint32_t> stream_id_list;
+  return stream_id_list;
+}
 } // namespace model_runner
 } // namespace ge
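This hunk appears to be a stub: it mirrors the existing GetTaskIdList stub and returns an empty static list. A hedged sketch of what a non-stub per-model lookup could look like; the member name and storage layout below are assumptions, not the actual graphengine implementation:

#include <cstdint>
#include <map>
#include <vector>

// Assumption-laden sketch: keep stream ids per model id and return an empty
// list for unknown models. Not the real ge::model_runner::ModelRunner.
class ModelRunnerSketch {
 public:
  const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const {
    static const std::vector<uint32_t> kEmpty;
    auto iter = stream_ids_by_model_.find(model_id);
    return iter == stream_ids_by_model_.end() ? kEmpty : iter->second;
  }

 private:
  std::map<uint32_t, std::vector<uint32_t>> stream_ids_by_model_;
};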