提交 bbafa9db 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5832 Use the unified Execute function to run Graph or Single Graph.

Merge pull request !5832 from 张清华/master
......@@ -318,7 +318,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
#endif
{
// run task on device
Execute(kernel_graph);
Execute(kernel_graph, true);
}
// summary
Summary(kernel_graph.get());
......@@ -348,17 +348,6 @@ void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelG
MS_LOG(INFO) << "Finish";
}
// Launches the kernels of `kernel_graph` on the Ascend runtime bound to this
// session's device_id_, logging "Start!"/"Finish!" around the launch.
// Goes through MS_LOG(EXCEPTION) with "Run task error!" when LaunchKernel
// reports failure.
void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
// Look up the kernel runtime registered for (kAscendDevice, device_id_).
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
// LaunchKernel takes the raw graph pointer; the shared_ptr keeps ownership.
bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get());
if (!ret_ok) {
MS_LOG(EXCEPTION) << "Run task error!";
}
MS_LOG(INFO) << "Finish!";
}
// Returns true when run_op_graphs_ already contains an entry keyed by
// `graph_info`, false otherwise. Does not modify the cache.
bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
}
......@@ -398,7 +387,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
// load input data to device
LoadInputData(graph, input_tensors);
// run op
RunOpExecTask(graph);
Execute(graph, false);
// get output
if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
......@@ -552,21 +541,30 @@ void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
// Loads `kernel_graph` onto the Ascend runtime bound to this session's
// device_id_. Reads MS_CTX_ENABLE_TASK_SINK from the global MsContext and
// forwards it to the runtime's Load call. Goes through MS_LOG(EXCEPTION) with
// "Load task error!" when the runtime reports failure.
void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// The task-sink flag is resolved here and passed down to the runtime.
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
// Return value is deliberately discarded (explicit void cast).
(void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
// NOTE(review): the next two statements both declare `ret_ok`; they look like
// the before/after forms of one line from the commit diff (one-argument Load
// vs. Load with is_task_sink) — confirm that only the second is meant to remain.
bool ret_ok = runtime_instance->Load(kernel_graph.get());
bool ret_ok = runtime_instance->Load(kernel_graph.get(), is_task_sink);
if (!ret_ok) {
MS_LOG(EXCEPTION) << "Load task error!";
}
MS_LOG(INFO) << "Finish!";
}
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const {
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const {
MS_LOG(INFO) << "Start!";
bool is_task_sink = false;
if (is_task) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
}
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
bool ret_ok = runtime_instance->Run(kernel_graph.get());
bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
if (!ret_ok) {
MS_LOG(EXCEPTION) << "run task error!";
}
......
......@@ -13,8 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
#include <unordered_map>
#include <string>
#include <memory>
......@@ -82,13 +84,12 @@ class AscendSession : public SessionBasic {
KernelGraph *kernel_graph) const;
void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const;
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs);
void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
// below functions are used for run op
void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
static void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
......
......@@ -118,7 +118,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
debugger_->PreExecute(kernel_graph);
}
#endif
bool ret = runtime_.Run(kernel_graph.get());
bool ret = runtime_.Run(kernel_graph.get(), false);
if (!ret) {
MS_LOG(EXCEPTION) << "Run graph failed";
}
......
......@@ -191,9 +191,9 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
#ifdef ENABLE_DEBUGGER
if (!runtime_instance->Run(kernel_graph.get(), debugger_.get())) {
if (!runtime_instance->Run(kernel_graph.get(), false, debugger_.get())) {
#else
if (!runtime_instance->Run(kernel_graph.get())) {
if (!runtime_instance->Run(kernel_graph.get(), false)) {
#endif
MS_LOG(EXCEPTION) << "GPU execute graph failed!";
}
......
......@@ -454,10 +454,7 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
}
bool AscendKernelRuntime::Load(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
if (!is_task_sink) {
return true;
}
......@@ -609,17 +606,14 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) {
}
}
bool AscendKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger) {
bool ret = false;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
#if defined(_WIN32) || defined(_WIN64)
auto start_time = std::chrono::steady_clock::now();
#else
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
#endif
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (is_task_sink) {
ret = RunTask(graph);
} else {
......
......@@ -44,8 +44,8 @@ class AscendKernelRuntime : public KernelRuntime {
bool GenTask(const session::KernelGraph *graph);
bool LoadTask(const session::KernelGraph *graph);
bool RunTask(const session::KernelGraph *graph);
bool Load(session::KernelGraph *graph) override;
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
bool Load(session::KernelGraph *graph, bool is_task_sink) override;
bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) override;
......
......@@ -287,7 +287,7 @@ void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutput
resource_manager_.DecreaseSummaryRefCount(summary_outputs);
}
bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, Debugger *debugger) {
bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, bool is_task_sink, Debugger *debugger) {
MS_EXCEPTION_IF_NULL(kernel_graph);
resource_manager_.IncreaseAddressRefCount(kernel_graph);
......
......@@ -36,7 +36,7 @@ class CPUKernelRuntime : public KernelRuntime {
~CPUKernelRuntime() override = default;
bool Init() override { return true; }
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
void AssignKernelAddress(session::KernelGraph *kernel_graph);
void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs);
......
......@@ -433,7 +433,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
}
}
bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger) {
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
bool ret = true;
......
......@@ -42,7 +42,7 @@ class GPUKernelRuntime : public KernelRuntime {
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) override;
void AssignMemory(session::KernelGraph *graph) override;
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
#ifdef ENABLE_DUMP_E2E
bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
#endif
......
......@@ -40,7 +40,7 @@ KernelRuntime::~KernelRuntime() {
#endif
}
// Base-class Load: does no work and always returns true; the parameters are
// not used in this default implementation.
// NOTE(review): the two definitions below appear to be the before/after
// signatures from the commit diff (the second adds `is_task_sink`) — confirm
// which one is meant to remain.
bool KernelRuntime::Load(session::KernelGraph *graph) { return true; }
bool KernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { return true; }
bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
if (graph != nullptr) {
......
......@@ -59,8 +59,8 @@ class KernelRuntime {
bool DumpDataEnabled();
bool DumpDataEnabledIteration();
virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr);
virtual bool Load(session::KernelGraph *graph);
virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) = 0;
virtual bool Load(session::KernelGraph *graph, bool is_task_sink);
virtual bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) = 0;
bool LaunchKernel(const session::KernelGraph *graph);
bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs,
const AddressPtrList &kernel_outputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册