提交 5ef6c082 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5756 Move ascend dependent functions to ascend kernel runtime.

Merge pull request !5756 from 张清华/master
......@@ -178,10 +178,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
#endif
// alloc mem
MemoryAlloc(root_graph.get());
// task generate
GenerateTaskInfo(root_graph);
// load task into device
LoadTask(root_graph);
// generate and load task into device
Load(root_graph);
DumpAllGraphs(all_graphs);
// return the root_graph id to backend
auto graph_id = root_graph->graph_id();
......@@ -258,10 +256,8 @@ void AscendSession::BuildGraph(GraphId graph_id) {
} else {
// alloc memory, including static memory and dynamic memory
MemoryAlloc(graph.get());
// generate task info for task sink mode
GenerateTaskInfo(graph);
// load task info to device if it is sink mode
LoadTask(graph);
// generate and load task info to device if it is sink mode
Load(graph);
}
// sync the initial const tensor to device
SyncInitialTenosrToDevice();
......@@ -322,7 +318,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
#endif
{
// run task on device
ExecTask(kernel_graph);
Execute(kernel_graph);
}
// summary
Summary(kernel_graph.get());
......@@ -554,30 +550,19 @@ void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
runtime_instance->RunOpClearMemory(kernel_graph);
}
void AscendSession::GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const {
void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
(void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
bool ret_ok = runtime_instance->GenTask(kernel_graph.get());
if (!ret_ok) {
MS_LOG(EXCEPTION) << "Generate task error!";
}
MS_LOG(INFO) << "Finish!";
}
void AscendSession::LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
bool ret_ok = runtime_instance->LoadTask(kernel_graph.get());
bool ret_ok = runtime_instance->Load(kernel_graph.get());
if (!ret_ok) {
MS_LOG(EXCEPTION) << "Load task error!";
}
MS_LOG(INFO) << "Finish!";
}
void AscendSession::ExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
......
......@@ -81,9 +81,8 @@ class AscendSession : public SessionBasic {
void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
KernelGraph *kernel_graph) const;
void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
void GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void ExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs);
void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
......
......@@ -454,19 +454,31 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
}
// Generates the task list for `graph` and then loads it, but only when the
// MS_CTX_ENABLE_TASK_SINK context flag is set.
//
// Returns true immediately (without touching `graph`) when task sink is
// disabled; otherwise returns true only if both GenTask and LoadTask succeed.
// LoadTask is not attempted when GenTask fails.
bool AscendKernelRuntime::Load(session::KernelGraph *graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
    // Nothing to generate or load outside of task sink mode.
    return true;
  }
  // Short-circuit keeps the original ordering: generate first, load second.
  return GenTask(graph) && LoadTask(graph);
}
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
SetContext();
if (graph == nullptr) {
MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
}
MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id();
#ifdef MEM_REUSE_DEBUG
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (!is_task_sink) {
return true;
}
#ifdef MEM_REUSE_DEBUG
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE)) {
// Get normal graph ir for memreuse
mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph);
......@@ -517,13 +529,6 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. ";
}
MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (!is_task_sink) {
return true;
}
if (GraphWithEmptyTaskList(graph)) {
MS_LOG(WARNING) << "LoadTask end, task list is empty";
return true;
......@@ -604,6 +609,36 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) {
}
}
// Executes `graph` and logs the elapsed wall time in microseconds.
//
// Dispatches to RunTask when the MS_CTX_ENABLE_TASK_SINK context flag is set
// and to LaunchKernel otherwise, returning whatever the chosen call returns.
// Timing uses std::chrono::steady_clock on Windows and gettimeofday elsewhere.
//
// NOTE(review): `debugger` is not used in this function.
// NOTE(review): the log line says "Success" regardless of the returned value.
bool AscendKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
#if defined(_WIN32) || defined(_WIN64)
  const auto begin = std::chrono::steady_clock::now();
#else
  struct timeval begin;
  struct timeval finish;
  (void)gettimeofday(&begin, nullptr);
#endif
  const bool task_sink_enabled = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  const bool result = task_sink_enabled ? RunTask(graph) : LaunchKernel(graph);
#if defined(_WIN32) || defined(_WIN64)
  const auto finish = std::chrono::steady_clock::now();
  std::chrono::duration<double, std::ratio<1, 1000000>> elapsed = finish - begin;
  MS_LOG(INFO) << "Call MS Run Success in " << elapsed.count() << " us";
#else
  (void)gettimeofday(&finish, nullptr);
  const uint64_t kUSecondInSecond = 1000000;
  // Whole seconds converted to microseconds, plus the sub-second remainder.
  uint64_t elapsed = kUSecondInSecond * static_cast<uint64_t>(finish.tv_sec - begin.tv_sec);
  elapsed += static_cast<uint64_t>(finish.tv_usec - begin.tv_usec);
  MS_LOG(INFO) << "Call MS Run Success in " << elapsed << " us";
#endif
  return result;
}
bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
SetContext();
MS_EXCEPTION_IF_NULL(graph);
......
......@@ -40,10 +40,12 @@ class AscendKernelRuntime : public KernelRuntime {
~AscendKernelRuntime() override;
bool Init() override;
bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
bool LoadData(session::KernelGraph *graph, Debugger *debugger) override;
bool GenTask(const session::KernelGraph *graph) override;
bool RunTask(const session::KernelGraph *graph) override;
bool LoadTask(const session::KernelGraph *graph) override;
bool LoadData(session::KernelGraph *graph, Debugger *debugger);
bool GenTask(const session::KernelGraph *graph);
bool LoadTask(const session::KernelGraph *graph);
bool RunTask(const session::KernelGraph *graph);
bool Load(session::KernelGraph *graph) override;
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) override;
......
......@@ -40,37 +40,8 @@ KernelRuntime::~KernelRuntime() {
#endif
}
// Executes `graph` and logs the elapsed wall time in microseconds.
//
// Calls RunTask when the MS_CTX_ENABLE_TASK_SINK context flag is set and
// LaunchKernel otherwise; the chosen call's result is returned unchanged.
// Timing uses std::chrono::steady_clock on Windows and gettimeofday elsewhere.
//
// NOTE(review): `debugger` is not used in this function.
// NOTE(review): the log line says "Success" regardless of the returned value.
bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
#if defined(_WIN32) || defined(_WIN64)
  const auto begin = std::chrono::steady_clock::now();
#else
  struct timeval begin;
  struct timeval finish;
  (void)gettimeofday(&begin, nullptr);
#endif
  const bool task_sink_enabled = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  const bool result = task_sink_enabled ? RunTask(graph) : LaunchKernel(graph);
#if defined(_WIN32) || defined(_WIN64)
  const auto finish = std::chrono::steady_clock::now();
  std::chrono::duration<double, std::ratio<1, 1000000>> elapsed = finish - begin;
  MS_LOG(INFO) << "Call MS Run Success in " << elapsed.count() << " us";
#else
  (void)gettimeofday(&finish, nullptr);
  const uint64_t kUSecondInSecond = 1000000;
  // Whole seconds converted to microseconds, plus the sub-second remainder.
  uint64_t elapsed = kUSecondInSecond * static_cast<uint64_t>(finish.tv_sec - begin.tv_sec);
  elapsed += static_cast<uint64_t>(finish.tv_usec - begin.tv_usec);
  MS_LOG(INFO) << "Call MS Run Success in " << elapsed << " us";
#endif
  return result;
}
// Base-class default for the load step: does nothing with `graph` and always
// reports success.
bool KernelRuntime::Load(session::KernelGraph *graph) {
  return true;
}
// for D to impl
bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
if (graph != nullptr) {
return true;
......@@ -78,37 +49,6 @@ bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *d
return false;
}
// for D to impl
// Base-class default: returns true when `graph` is non-null and false
// otherwise. `debugger` is not consulted.
bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
  return graph != nullptr;
}
// for D to impl
// Base-class default: returns true when `graph` is non-null and false
// otherwise; no task is actually generated here.
bool KernelRuntime::GenTask(const session::KernelGraph *graph) {
  return graph != nullptr;
}
// Base-class default: returns true when `graph` is non-null and false
// otherwise; no task is actually loaded here.
bool KernelRuntime::LoadTask(const session::KernelGraph *graph) {
  return graph != nullptr;
}
// for D to impl
// Base-class default: returns true when `graph` is non-null and false
// otherwise; no task is actually run here.
bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
  return graph != nullptr;
}
bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::OutputAddrExist(kernel, index)) {
......
......@@ -58,11 +58,9 @@ class KernelRuntime {
void RunOpClearMemory(const session::KernelGraph *graph);
bool DumpDataEnabled();
bool DumpDataEnabledIteration();
virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr);
virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr);
virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger);
virtual bool RunTask(const session::KernelGraph *graph);
virtual bool GenTask(const session::KernelGraph *graph);
virtual bool Load(session::KernelGraph *graph);
virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) = 0;
bool LaunchKernel(const session::KernelGraph *graph);
bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs,
const AddressPtrList &kernel_outputs,
......@@ -80,7 +78,6 @@ class KernelRuntime {
#ifdef ENABLE_DUMP_E2E
DumpConfPtr GetDumpConf();
#endif
virtual bool LoadTask(const session::KernelGraph *graph);
// for GPU and D to impl
virtual void ReleaseDeviceRes() {}
void set_device_id(uint32_t device_id) { device_id_ = device_id; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册