Commit c0070d3d authored by Zhang Qinghua

Use the unified Execute function to run Graph or Single Op Graph.

Parent 77dd91a6
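For context: this commit collapses the separate whole-graph and single-op execution paths of AscendSession into a single Execute entry point, and threads an explicit is_task_sink flag through the KernelRuntime interface instead of having each runtime re-read it from MsContext. A minimal sketch of the resulting call pattern, condensed from the hunks below (logging and error handling elided):

// Session side: one entry point, two call sites.
Execute(kernel_graph, true);   // RunGraph: whole graph, task sink allowed
Execute(graph, false);         // RunOp: single-op graph, never task sink

// Inside Execute: the context is consulted only for whole-graph runs.
bool is_task_sink = false;
if (is_task) {
  is_task_sink = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
}
runtime_instance->Run(kernel_graph.get(), is_task_sink);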
......@@ -318,7 +318,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
#endif
  {
    // run task on device
    Execute(kernel_graph);
    Execute(kernel_graph, true);
  }
  // summary
  Summary(kernel_graph.get());
......@@ -348,17 +348,6 @@ void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelG
  MS_LOG(INFO) << "Finish";
}
void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get());
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "Run task error!";
  }
  MS_LOG(INFO) << "Finish!";
}
bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
  return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
}
......@@ -398,7 +387,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
  // load input data to device
  LoadInputData(graph, input_tensors);
  // run op
  RunOpExecTask(graph);
  Execute(graph, false);
  // get output
  if (op_run_info.value != nullptr) {
    std::vector<tensor::TensorPtr> pre_output_tensors;
......@@ -552,21 +541,30 @@ void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->Load(kernel_graph.get());
  bool ret_ok = runtime_instance->Load(kernel_graph.get(), is_task_sink);
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "Load task error!";
  }
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const {
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const {
  MS_LOG(INFO) << "Start!";
  bool is_task_sink = false;
  if (is_task) {
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  }
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->Run(kernel_graph.get());
  bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "run task error!";
  }
......
......@@ -13,8 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
#include <unordered_map>
#include <string>
#include <memory>
......@@ -82,13 +84,12 @@ class AscendSession : public SessionBasic {
KernelGraph *kernel_graph) const;
void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const;
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs);
void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
// below functions are used for run op
void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
static void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
......
......@@ -118,7 +118,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
    debugger_->PreExecute(kernel_graph);
  }
#endif
  bool ret = runtime_.Run(kernel_graph.get());
  bool ret = runtime_.Run(kernel_graph.get(), false);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Run graph failed";
  }
......
......@@ -191,9 +191,9 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
#ifdef ENABLE_DEBUGGER
  if (!runtime_instance->Run(kernel_graph.get(), debugger_.get())) {
  if (!runtime_instance->Run(kernel_graph.get(), false, debugger_.get())) {
#else
  if (!runtime_instance->Run(kernel_graph.get())) {
  if (!runtime_instance->Run(kernel_graph.get(), false)) {
#endif
    MS_LOG(EXCEPTION) << "GPU execute graph failed!";
  }
......
......@@ -454,10 +454,7 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
}
bool AscendKernelRuntime::Load(session::KernelGraph *graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
  if (!is_task_sink) {
    return true;
  }
......@@ -609,17 +606,14 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) {
  }
}
bool AscendKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger) {
  bool ret = false;
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
#if defined(_WIN32) || defined(_WIN64)
  auto start_time = std::chrono::steady_clock::now();
#else
  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);
#endif
  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  if (is_task_sink) {
    ret = RunTask(graph);
  } else {
......
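The else branch of this dispatch is truncated above. Judging from the removed RunOpExecTask and the LaunchKernel declaration in kernel_runtime.h further down, the non-task-sink path presumably launches kernels one by one; a hedged sketch of the dispatch, not the verbatim source:

if (is_task_sink) {
  ret = RunTask(graph);       // sink the whole graph to the device as one task
} else {
  ret = LaunchKernel(graph);  // launch kernels individually on the stream
}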
......@@ -44,8 +44,8 @@ class AscendKernelRuntime : public KernelRuntime {
  bool GenTask(const session::KernelGraph *graph);
  bool LoadTask(const session::KernelGraph *graph);
  bool RunTask(const session::KernelGraph *graph);
  bool Load(session::KernelGraph *graph) override;
  bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
  bool Load(session::KernelGraph *graph, bool is_task_sink) override;
  bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
  void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
                                 const std::unordered_set<ValueNodePtr> &value_nodes,
                                 const std::vector<CNodePtr> &execution_order) override;
......
......@@ -287,7 +287,7 @@ void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutput
  resource_manager_.DecreaseSummaryRefCount(summary_outputs);
}
bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, Debugger *debugger) {
bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, bool is_task_sink, Debugger *debugger) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  resource_manager_.IncreaseAddressRefCount(kernel_graph);
......
......@@ -36,7 +36,7 @@ class CPUKernelRuntime : public KernelRuntime {
  ~CPUKernelRuntime() override = default;
  bool Init() override { return true; }
  bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
  bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
  void AssignKernelAddress(session::KernelGraph *kernel_graph);
  void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                       VectorRef *outputs);
......
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_
#include <string>
#include <memory>
#include <vector>
#include <set>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/optimizer/mem_reuse/mem_swap_manager.h"
namespace mindspore {
namespace device {
namespace gpu {
using mindspore::device::memswap::MemSwapManagerPtr;
class GPUKernelRuntime : public KernelRuntime {
 public:
  GPUKernelRuntime() = default;
  ~GPUKernelRuntime() override = default;
  bool Init() override;
  void ReleaseDeviceRes() override;
  void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
                                 const std::unordered_set<ValueNodePtr> &value_nodes,
                                 const std::vector<CNodePtr> &execution_order) override;
  void AssignMemory(session::KernelGraph *graph) override;
  bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
  bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
#ifdef ENABLE_DUMP_E2E
  bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
#endif
 protected:
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                       TypeId type_id) override;
  bool SyncStream() override;
 private:
  GPUKernelRuntime(const GPUKernelRuntime &);
  GPUKernelRuntime &operator=(const GPUKernelRuntime &);
  bool InitDevice();
  bool device_init_{false};
  // The related functions and members for using dynamic memory pool.
  void InitKernelRefCount(const session::KernelGraph *graph);
  void InitKernelOutputAddress(const session::KernelGraph *graph);
  void InitKernelWorkspaceAddress(const session::KernelGraph *graph);
  void InitMemorySwapInfo(const session::KernelGraph *graph);
  void SaveGraphOutputNode(const session::KernelGraph *graph);
  bool IsGraphOutput(const session::KernelGraph *graph, const mindspore::AnfNodePtr &kernel) const;
  void ClearKernelOutputAddress(const session::KernelGraph *graph);
  void ClearKernelWorkspaceAddress(const session::KernelGraph *graph);
  void ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph);
  bool RunOneStep(const session::KernelGraph *graph, Debugger *debugger = nullptr);
  bool SearchMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr);
  bool RefineMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr);
  bool LaunchKernelDynamic(const session::KernelGraph *graph, Debugger *debugger = nullptr, bool mock = false,
                           bool profiling = false);
  void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs,
                                     const AddressPtrList &workspace, const AddressPtrList &outputs);
  bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock);
  bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
                             AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
                             AddressPtrList *kernel_outputs, bool mock);
  bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, bool mock);
  bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
                                   AddressPtrList *kernel_outputs, bool mock);
  bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
                                      const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces,
                                      bool mock);
  void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph);
  void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel);
  void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel);
  void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory,
                                  const DeviceAddressPtrList addr_list, size_t total_size,
                                  std::vector<size_t> size_list);
  void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel);
  bool UpdateMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling);
  bool AddMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling);
  void UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock);
  void UpdateHostSwapOutQueue(bool mock);
  void ClearSwapInfo(bool mock);
  std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
  std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
  std::unordered_map<uint32_t, bool> is_first_step_map_;
  std::unordered_map<uint32_t, std::set<AnfNodePtr>> graph_output_map_;
  MemReuseUtilPtr mem_reuse_util_{nullptr};
  MemSwapManagerPtr mem_swap_manager_{nullptr};
};
MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime);
}  // namespace gpu
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_
......@@ -40,7 +40,7 @@ KernelRuntime::~KernelRuntime() {
#endif
}
bool KernelRuntime::Load(session::KernelGraph *graph) { return true; }
bool KernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { return true; }
bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
if (graph != nullptr) {
......
......@@ -59,8 +59,8 @@ class KernelRuntime {
  bool DumpDataEnabled();
  bool DumpDataEnabledIteration();
  virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr);
  virtual bool Load(session::KernelGraph *graph);
  virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) = 0;
  virtual bool Load(session::KernelGraph *graph, bool is_task_sink);
  virtual bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) = 0;
  bool LaunchKernel(const session::KernelGraph *graph);
  bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs,
                                     const AddressPtrList &kernel_outputs,
......
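With the base-class change above, every concrete runtime must accept the flag even when it is irrelevant; the CPU runtime, for instance, now receives is_task_sink but always runs kernel by kernel. A minimal sketch of a conforming subclass (the class name and body are illustrative only, and the other pure-virtual members of KernelRuntime are elided):

class DemoKernelRuntime : public KernelRuntime {
 public:
  bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override {
    (void)is_task_sink;          // this backend has no task-sink mode
    (void)debugger;
    return LaunchKernel(graph);  // base-class helper: launch kernels one by one
  }
};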