Commit c0070d3d authored by Zhang Qinghua

Use the unified Execute function to run Graph or Single Op Graph.

Parent 77dd91a6
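In short, both the whole-graph path (AscendSession::RunGraph) and the single-op path (AscendSession::RunOp) now call one Execute function, which resolves the task-sink flag and forwards it to KernelRuntime::Run. Below is a minimal, self-contained sketch of that pattern; the Session/Runtime classes and the task_sink_enabled_ flag are simplified stand-ins used only for illustration, not the actual MindSpore types.

// Hypothetical, simplified model of the refactoring in this commit.
#include <iostream>
#include <memory>

struct KernelGraph {};  // stand-in for session::KernelGraph

class Runtime {
 public:
  // After the change the caller passes the resolved task-sink flag instead of
  // the runtime querying the global context itself.
  bool Run(const KernelGraph *graph, bool is_task_sink) {
    if (graph == nullptr) return false;
    if (is_task_sink) {
      std::cout << "run the whole graph as a sunk task\n";
    } else {
      std::cout << "launch kernels one by one\n";
    }
    return true;
  }
};

class Session {
 public:
  // Unified entry point: is_task == true for a full graph, false for a
  // single-op graph built by RunOp.
  void Execute(const std::shared_ptr<KernelGraph> &graph, bool is_task) {
    bool is_task_sink = false;
    if (is_task) {
      is_task_sink = task_sink_enabled_;  // in MindSpore this comes from MsContext
    }
    if (!runtime_.Run(graph.get(), is_task_sink)) {
      std::cerr << "run task error!\n";
    }
  }

  void RunGraph(const std::shared_ptr<KernelGraph> &graph) { Execute(graph, true); }
  void RunOp(const std::shared_ptr<KernelGraph> &graph) { Execute(graph, false); }

 private:
  Runtime runtime_;
  bool task_sink_enabled_ = true;
};

int main() {
  Session session;
  auto graph = std::make_shared<KernelGraph>();
  session.RunGraph(graph);  // prints: run the whole graph as a sunk task
  session.RunOp(graph);     // prints: launch kernels one by one
  return 0;
}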
@@ -318,7 +318,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
 #endif
   {
     // run task on device
-    Execute(kernel_graph);
+    Execute(kernel_graph, true);
   }
   // summary
   Summary(kernel_graph.get());
@@ -348,17 +348,6 @@ void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelG
   MS_LOG(INFO) << "Finish";
 }
-void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
-  MS_LOG(INFO) << "Start!";
-  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
-  MS_EXCEPTION_IF_NULL(runtime_instance);
-  bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get());
-  if (!ret_ok) {
-    MS_LOG(EXCEPTION) << "Run task error!";
-  }
-  MS_LOG(INFO) << "Finish!";
-}
 bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
   return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
 }
@@ -398,7 +387,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
   // load input data to device
   LoadInputData(graph, input_tensors);
   // run op
-  RunOpExecTask(graph);
+  Execute(graph, false);
   // get output
   if (op_run_info.value != nullptr) {
     std::vector<tensor::TensorPtr> pre_output_tensors;
@@ -552,21 +541,30 @@ void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
 void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
   MS_LOG(INFO) << "Start!";
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
   (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  bool ret_ok = runtime_instance->Load(kernel_graph.get());
+  bool ret_ok = runtime_instance->Load(kernel_graph.get(), is_task_sink);
   if (!ret_ok) {
     MS_LOG(EXCEPTION) << "Load task error!";
   }
   MS_LOG(INFO) << "Finish!";
 }
-void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const {
   MS_LOG(INFO) << "Start!";
+  bool is_task_sink = false;
+  if (is_task) {
+    auto context_ptr = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context_ptr);
+    is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
+  }
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  bool ret_ok = runtime_instance->Run(kernel_graph.get());
+  bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
   if (!ret_ok) {
     MS_LOG(EXCEPTION) << "run task error!";
   }
......
@@ -13,8 +13,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
 #define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
 #include <unordered_map>
 #include <string>
 #include <memory>
@@ -82,13 +84,12 @@ class AscendSession : public SessionBasic {
                       KernelGraph *kernel_graph) const;
   void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
   void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
-  void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const;
   void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs);
   void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   // below functions are used for run op
   void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
-  void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
   static void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
......
@@ -118,7 +118,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
     debugger_->PreExecute(kernel_graph);
   }
 #endif
-  bool ret = runtime_.Run(kernel_graph.get());
+  bool ret = runtime_.Run(kernel_graph.get(), false);
   if (!ret) {
     MS_LOG(EXCEPTION) << "Run graph failed";
   }
......
@@ -191,9 +191,9 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
 #ifdef ENABLE_DEBUGGER
-  if (!runtime_instance->Run(kernel_graph.get(), debugger_.get())) {
+  if (!runtime_instance->Run(kernel_graph.get(), false, debugger_.get())) {
 #else
-  if (!runtime_instance->Run(kernel_graph.get())) {
+  if (!runtime_instance->Run(kernel_graph.get(), false)) {
 #endif
     MS_LOG(EXCEPTION) << "GPU execute graph failed!";
   }
......
@@ -454,10 +454,7 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
   return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
 }
-bool AscendKernelRuntime::Load(session::KernelGraph *graph) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
+bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
   if (!is_task_sink) {
     return true;
   }
@@ -609,17 +606,14 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) {
   }
 }
-bool AscendKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
+bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger) {
   bool ret = false;
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
 #if defined(_WIN32) || defined(_WIN64)
   auto start_time = std::chrono::steady_clock::now();
 #else
   struct timeval start_time, end_time;
   (void)gettimeofday(&start_time, nullptr);
 #endif
-  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
   if (is_task_sink) {
     ret = RunTask(graph);
   } else {
......
@@ -44,8 +44,8 @@ class AscendKernelRuntime : public KernelRuntime {
   bool GenTask(const session::KernelGraph *graph);
   bool LoadTask(const session::KernelGraph *graph);
   bool RunTask(const session::KernelGraph *graph);
-  bool Load(session::KernelGraph *graph) override;
-  bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
+  bool Load(session::KernelGraph *graph, bool is_task_sink) override;
+  bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
   void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
                                  const std::unordered_set<ValueNodePtr> &value_nodes,
                                  const std::vector<CNodePtr> &execution_order) override;
......
@@ -287,7 +287,7 @@ void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutput
   resource_manager_.DecreaseSummaryRefCount(summary_outputs);
 }
-bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, Debugger *debugger) {
+bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, bool is_task_sink, Debugger *debugger) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   resource_manager_.IncreaseAddressRefCount(kernel_graph);
......
@@ -36,7 +36,7 @@ class CPUKernelRuntime : public KernelRuntime {
   ~CPUKernelRuntime() override = default;
   bool Init() override { return true; }
-  bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
+  bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
   void AssignKernelAddress(session::KernelGraph *kernel_graph);
   void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                        VectorRef *outputs);
......
@@ -433,7 +433,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
   }
 }
-bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
+bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger) {
   struct timeval start_time, end_time;
   (void)gettimeofday(&start_time, nullptr);
   bool ret = true;
......
@@ -42,7 +42,7 @@ class GPUKernelRuntime : public KernelRuntime {
                                  const std::unordered_set<ValueNodePtr> &value_nodes,
                                  const std::vector<CNodePtr> &execution_order) override;
   void AssignMemory(session::KernelGraph *graph) override;
-  bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
+  bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
 #ifdef ENABLE_DUMP_E2E
   bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
 #endif
......
@@ -40,7 +40,7 @@ KernelRuntime::~KernelRuntime() {
 #endif
 }
-bool KernelRuntime::Load(session::KernelGraph *graph) { return true; }
+bool KernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { return true; }
 bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
   if (graph != nullptr) {
......
@@ -59,8 +59,8 @@ class KernelRuntime {
   bool DumpDataEnabled();
   bool DumpDataEnabledIteration();
   virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr);
-  virtual bool Load(session::KernelGraph *graph);
-  virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) = 0;
+  virtual bool Load(session::KernelGraph *graph, bool is_task_sink);
+  virtual bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) = 0;
   bool LaunchKernel(const session::KernelGraph *graph);
   bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs,
                                      const AddressPtrList &kernel_outputs,
......