提交 b6ff2724 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1990 fix profiling stream id

Merge pull request !1990 from caifubi/fix-profiling-stream-id
graphengine @ c54db434
Subproject commit 45ca7863ac6410c8e2f83168481ddc6b43bcea33 → Subproject commit c54db4343f83cb0c15cc3b5c9755926de27fa3af
...@@ -367,7 +367,8 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { ...@@ -367,7 +367,8 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
} }
if (ProfilingManager::GetInstance().IsProfiling()) { if (ProfilingManager::GetInstance().IsProfiling()) {
auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first); auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);
ProfilingUtils::ReportProfilingData(task_ids, NOT_NULL(graph)); auto stream_ids = ge::model_runner::ModelRunner::Instance().GetStreamIdList(model_iter->first);
ProfilingUtils::ReportProfilingData(task_ids, stream_ids, NOT_NULL(graph));
} }
return true; return true;
} }
......
...@@ -302,7 +302,7 @@ bool ProfilingUtils::ValidComputeGraph(NotNull<const session::KernelGraph *> gra ...@@ -302,7 +302,7 @@ bool ProfilingUtils::ValidComputeGraph(NotNull<const session::KernelGraph *> gra
return false; return false;
} }
void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids, void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids, const std::vector<uint32_t> &stream_ids,
NotNull<const session::KernelGraph *> graph) { NotNull<const session::KernelGraph *> graph) {
if (!ValidComputeGraph(graph)) { if (!ValidComputeGraph(graph)) {
MS_LOG(WARNING) << "Not a valid compute graph:" << graph->graph_id(); MS_LOG(WARNING) << "Not a valid compute graph:" << graph->graph_id();
...@@ -319,6 +319,7 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids, ...@@ -319,6 +319,7 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second); TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second);
task_reporter.set_task_ids(task_ids); task_reporter.set_task_ids(task_ids);
task_reporter.set_stream_ids(stream_ids);
task_reporter.ReportData(); task_reporter.ReportData();
GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second); GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second);
......
...@@ -87,7 +87,8 @@ class ProfilingUtils { ...@@ -87,7 +87,8 @@ class ProfilingUtils {
// Mapping task_id and kernel name for device to generate the time cost of specific kernel. // Mapping task_id and kernel name for device to generate the time cost of specific kernel.
// Device calculate the time cost of the task which is marked by task id. // Device calculate the time cost of the task which is marked by task id.
// But we need data of (kernel name , time cost) // But we need data of (kernel name , time cost)
static void ReportProfilingData(const std::vector<uint32_t> &task_ids, NotNull<const session::KernelGraph *> graph); static void ReportProfilingData(const std::vector<uint32_t> &task_ids, const std::vector<uint32_t> &stream_ids,
NotNull<const session::KernelGraph *> graph);
// Get profiling trace point from envs. // Get profiling trace point from envs.
// export PROFILING_FP_START='full name of the first cnode to execute' // export PROFILING_FP_START='full name of the first cnode to execute'
......
...@@ -40,12 +40,22 @@ void TaskDescReporter::ReportData() { ...@@ -40,12 +40,22 @@ void TaskDescReporter::ReportData() {
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod); auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(ascend_kernel_mod); MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
auto desc_ptr = std::make_shared<TaskDesc>(node->fullname_with_scope(), task_ids_[task_index++], // Check task_id and stream_id valid
ascend_kernel_mod->block_dim(), ascend_kernel_mod->stream_id()); CheckStreamTaskValid(task_index, task_index);
auto desc_ptr = std::make_shared<TaskDesc>(node->fullname_with_scope(), task_ids_[task_index],
ascend_kernel_mod->block_dim(), stream_ids_[task_index]);
prof_desc_.emplace_back(desc_ptr); prof_desc_.emplace_back(desc_ptr);
++task_index;
} }
DescReporter::ReportData(); DescReporter::ReportData();
} }
// Validate that the given indices can be used to read task_ids_ and stream_ids_.
// Raises a logged exception (via MS_LOG(EXCEPTION)) when either index is out of
// range, reporting both indices and both container sizes for diagnosis.
void TaskDescReporter::CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id) {
  const bool task_in_range = task_id < task_ids_.size();
  const bool stream_in_range = stream_id < stream_ids_.size();
  if (!(task_in_range && stream_in_range)) {
    MS_LOG(EXCEPTION) << "Index invalid. task_id:" << task_id << ", task_ids.size:" << task_ids_.size()
                      << ", stream_id:" << stream_id << ", stream_ids.size:" << stream_ids_.size();
  }
}
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore
...@@ -32,9 +32,12 @@ class TaskDescReporter : public DescReporter { ...@@ -32,9 +32,12 @@ class TaskDescReporter : public DescReporter {
~TaskDescReporter() override = default; ~TaskDescReporter() override = default;
void ReportData() override; void ReportData() override;
void set_task_ids(const std::vector<uint32_t> &task_ids) { task_ids_ = task_ids; } void set_task_ids(const std::vector<uint32_t> &task_ids) { task_ids_ = task_ids; }
void set_stream_ids(const std::vector<uint32_t> &stream_ids) { stream_ids_ = stream_ids; }
private: private:
std::vector<uint32_t> task_ids_; std::vector<uint32_t> task_ids_;
std::vector<uint32_t> stream_ids_;
void CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id);
}; };
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
......
...@@ -40,6 +40,11 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const ...@@ -40,6 +40,11 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
static std::vector<uint32_t> task_id_list; static std::vector<uint32_t> task_id_list;
return task_id_list; return task_id_list;
} }
// Stub implementation: stream ids are not tracked in this build, so an empty
// list is returned for every model_id. The static lives for the program's
// lifetime, keeping the returned reference valid for callers.
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
  static std::vector<uint32_t> empty_stream_ids;
  return empty_stream_ids;
}
} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册