diff --git a/mindspore/ccsrc/backend/session/CMakeLists.txt b/mindspore/ccsrc/backend/session/CMakeLists.txt index b6d340a3258ba222407d654ffa1e30584932fec9..b97900abd2add62fe4152be6e238f08103f46b10 100644 --- a/mindspore/ccsrc/backend/session/CMakeLists.txt +++ b/mindspore/ccsrc/backend/session/CMakeLists.txt @@ -3,6 +3,8 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "kernel_graph.cc" "session_basic.cc" "session_factory.cc" + "executor.cc" + "executor_manager.cc" "anf_runtime_algorithm.cc" ) diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 63e9fa1cc0ac82b9f718bef01a72cf1d049df794..e416f0b93a34e9cd14ff0b84b1ff4a1d4cb94731 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -312,7 +312,6 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vectorexecutable()) { MS_LOG(INFO) << "No child graph has anf output"; - UpdateOutputs(kernel_graph, outputs, inputs); return; } // load input data from user input @@ -322,12 +321,9 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector &input_tensors) { +void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, VectorRef *outputs) { auto graph = run_op_graphs_[graph_info]; MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; @@ -408,7 +404,6 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr // run op RunOpExecTask(graph); // get output - VectorRef outputs; if (op_run_info.value != nullptr) { std::vector pre_output_tensors; TensorValueToTensor(op_run_info.value, &pre_output_tensors); @@ -416,22 +411,13 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr tensor::TensorPtr tensor = std::make_shared(pre_output->data_type(), pre_output->shape()); tensor->set_device_address(pre_output->device_address()); tensor->set_dirty(false); - outputs.emplace_back(tensor); + outputs->emplace_back(tensor); } } else { - UpdateOutputs(graph, &outputs, input_tensors); + UpdateOutputs(graph, outputs, input_tensors); } - // trans output to tuple - auto output_tensors = TransformBaseRefListToTuple(outputs); - if (!utils::isa(output_tensors) || - !py::isinstance(utils::cast(output_tensors).object_)) { - MS_LOG(EXCEPTION) << "The output tensors should be a tuple !"; - } - py::object tuple_obj = utils::cast(output_tensors).object_; - py::tuple tuple_tensors = py::cast(tuple_obj); RunOpMemoryClear(graph.get()); MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; - return tuple_tensors; } // compile graph steps diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index e8b624163c449ddba5b1a736cb2c09f30ff888ff..91a86e6d10bd823d538ec0a7aec2d1680eadce5f 100755 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -29,7 +29,7 @@ #include "backend/kernel_compiler/kernel.h" #include "backend/session/session_factory.h" #include "backend/session/ascend_control_parser.h" -#include "runtime/device/ascend/ascend_memory_pool.h" +#include "runtime/context.h" namespace mindspore { namespace session { @@ -38,10 +38,17 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, class AscendSession : public SessionBasic { public: AscendSession() { final_graph_id_ = kInvalidGraphId; } - ~AscendSession() override { mindspore::device::ascend::AscendMemoryPool::GetInstance().ResetIdleMemBuf(); } + ~AscendSession() override = default; void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kAscendDevice, device_id); + InitDevice(kAscendDevice, device_id); + auto ret = rtCtxCreate(&rt_context_, 0, device_id); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast(ret) << "]"; + } + ret = rtCtxSetCurrent(rt_context_); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; + } } GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; GraphId CompileGraph(NotNull func_graph) override; @@ -49,8 +56,8 @@ class AscendSession : public SessionBasic { void BuildGraph(GraphId) override; void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, const std::vector &tensors_mask) override; - py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) override; + void RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, VectorRef *outputs) override; // get graph id in child graphs by ME front anf node pointer GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; @@ -121,6 +128,8 @@ class AscendSession : public SessionBasic { std::map, tensor::TensorPtr> initial_tenosrs_; // final_graph_id is used in every root graph has it's own session situation GraphId final_graph_id_; + // ascend runtime context + rtContext_t rt_context_{nullptr}; }; MS_REG_SESSION(kAscendDevice, AscendSession); } // namespace session diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc index 9a9b404f11771a64e2b980f054725ad367b51503..5397d1580862d668bd1b59536d752b1666af1f1e 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.cc +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -83,15 +83,23 @@ GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList return graph_id; } +void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector &input_tensors, + VectorRef *outputs, + std::map *tensor_to_node) { + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_LOG(INFO) << "Bind input output address"; + runtime_.BindInputOutput(kernel_graph.get(), input_tensors, outputs); + return; +} + void CPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { - auto &kernel_graph = graphs_[graph_id]; + auto kernel_graph = GetGraph(graph_id); MS_EXCEPTION_IF_NULL(kernel_graph); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) InitPSParamAndOptim(kernel_graph, inputs); #endif - MS_LOG(INFO) << "Bind input output address"; - std::vector need_sync_outputs; - runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs, &need_sync_outputs); + MS_LOG(INFO) << "Run graph start"; auto execution_order = kernel_graph->execution_order(); Reorder(&execution_order); @@ -114,9 +122,6 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vectordata_sync(); - } if (enable_summary) { Summary(kernel_graph.get()); diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h index 014b4168ab04c115e82f43cad3dac46d5b2a6c43..08e09d929eaf0bc93e7b9afbc9804bd192c29ebd 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.h +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -17,6 +17,7 @@ #define MINDSPORE_CCSRC_BACKEND_SESSION_CPU_SESSION_H #include #include +#include #include #include "backend/session/session_basic.h" #include "backend/session/kernel_graph.h" @@ -28,13 +29,13 @@ class CPUSession : public SessionBasic { public: CPUSession() = default; ~CPUSession() override = default; - void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kCPUDevice, device_id); - } + void Init(uint32_t device_id) override { InitDevice(kCPUDevice, device_id); } GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + void CreateOutputTensors(const GraphId &graph_id, const std::vector &input_tensors, VectorRef *, + std::map *tensor_to_node) override; + protected: ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; void Optimize(const std::shared_ptr &kernel_graph); diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..9353d6931d61bac2088c6575461ed82c015ab357 --- /dev/null +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -0,0 +1,292 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/executor.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "backend/session/executor_manager.h" + +namespace mindspore { +namespace session { +namespace { +void UpdateOutputTensors(VectorRef *outputs, + const std::map &tensor_to_node) { + MS_EXCEPTION_IF_NULL(outputs); + for (auto item : *outputs) { + if (utils::isa(item)) { + auto vector_ref = utils::cast(item); + UpdateOutputTensors(&vector_ref, tensor_to_node); + } else if (utils::isa(item)) { + auto tensor = utils::cast(item); + MS_EXCEPTION_IF_NULL(tensor); + tensor->SetNeedWait(false); + auto iter = tensor_to_node.find(tensor); + if (iter != tensor_to_node.end()) { + auto &node = iter->second.first; + auto &output_index = iter->second.second; + auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); + tensor->set_device_address(address); + } + if (tensor->need_sync()) { + tensor->data_sync(); + tensor->set_need_sync(false); + } + } + } +} + +BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) { + if (utils::isa(base_ref)) { + auto ref_list = utils::cast(base_ref); + py::tuple output_tensors(ref_list.size()); + for (size_t i = 0; i < ref_list.size(); ++i) { + auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef + if (utils::isa(output)) { + auto tensor_ptr = utils::cast(output); + MS_EXCEPTION_IF_NULL(tensor_ptr); + output_tensors[i] = tensor_ptr; + } else if (utils::isa(output)) { + py::object obj = utils::cast(output).object_; + py::tuple tensor_tuple = py::cast(obj); + output_tensors[i] = tensor_tuple; + } else { + MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; + } + } + return output_tensors; // turn tuple to py::object and store in PyObjectRef + } else if (utils::isa(base_ref)) { + return base_ref; + } else { + MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; + } +} +} // namespace +void CompileNodesTask::Run() { + MS_EXCEPTION_IF_NULL(session_); + graph_id_ = session_->CompileGraph(nodes_, output_nodes_); +} + +void CompileGraphTask::Run() { + MS_EXCEPTION_IF_NULL(session_); + graph_id_ = session_->CompileGraph(NOT_NULL(func_graph_)); +} + +void BuildGraphTask::Run() { + MS_EXCEPTION_IF_NULL(session_); + session_->BuildGraph(graph_id_); +} + +void RunGraphTask::Run() { + MS_EXCEPTION_IF_NULL(session_); + session_->RunGraph(graph_id_, input_tensors_, &outputs_); + UpdateOutputTensors(&outputs_, tensor_to_node_); + ExecutorManager::Instance().OnRunGraphFinished(); +} + +void BuildOpTask::Run() { + MS_EXCEPTION_IF_NULL(session_); + session_->BuildOp(*op_run_info_, graph_info_, input_tensors_, tensors_mask_); +} + +void RunOpTask::Run() { + MS_EXCEPTION_IF_NULL(session_); + session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_); +} + +Executor::Executor(const std::string &device_name, uint32_t device_id) { + device_name_ = device_name; + device_id_ = device_id; + worker_ = std::make_shared(&Executor::WorkerLoop, this); +} + +void Executor::WorkerJoin() { + StopWorker(); + worker_->join(); +} + +void Executor::WorkerLoop() { + while (true) { + std::shared_ptr task; + { + std::unique_lock lock(task_mutex_); + task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); }); + task = ready_tasks_.front(); + ready_tasks_.pop(); + } + if (task->type_ == kExit) { + OnWorkerExit(); + return; + } + task->Run(); + if (task->type_ == kCompileNodes) { + compile_cond_var_.notify_all(); + } else if (task->type_ == kCompileGraph) { + compile_cond_var_.notify_all(); + } else if (task->type_ == kBuildGraph) { + build_cond_var_.notify_all(); + } else if (task->type_ == kRunGraph) { + run_cond_var_.notify_all(); + } else if (task->type_ == kBuildOp) { + build_op_cond_var_.notify_all(); + } else if (task->type_ == kRunOp) { + run_op_cond_var_.notify_all(); + } + } +} + +std::vector> Executor::GetNewReadyTasks() { + std::vector> new_ready_tasks; + std::unique_lock lock(pending_task_mutex_); + for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) { + auto task = *iter; + if (IsAllInputsReady(task->input_tensors_)) { + new_ready_tasks.emplace_back(task); + pending_tasks_.erase(iter++); + } else { + iter++; + } + } + return new_ready_tasks; +} + +void Executor::OnRunGraphFinished() { + auto new_ready_tasks = GetNewReadyTasks(); + std::unique_lock lock(task_mutex_); + for (auto &task : new_ready_tasks) { + ready_tasks_.push(task); + } + if (new_ready_tasks.size() > 0) { + task_cond_var_.notify_all(); + } +} + +bool Executor::IsAllInputsReady(const std::vector &inputs) { + for (auto &input : inputs) { + MS_EXCEPTION_IF_NULL(input); + if (input->NeedWait()) { + return false; + } + } + return true; +} + +GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrList &lst, + const AnfNodePtrList &outputs) { + std::unique_lock lock(task_mutex_); + auto task = std::make_shared(); + task->session_ = session; + task->nodes_ = lst; + task->output_nodes_ = outputs; + ready_tasks_.push(task); + task_cond_var_.notify_all(); + compile_cond_var_.wait(lock); + return task->graph_id_; +} + +GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull func_graph) { + std::unique_lock lock(task_mutex_); + auto task = std::make_shared(); + task->session_ = session; + task->func_graph_ = func_graph; + ready_tasks_.push(task); + task_cond_var_.notify_all(); + compile_cond_var_.wait(lock); + return task->graph_id_; +} + +void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) { + std::unique_lock lock(task_mutex_); + auto task = std::make_shared(); + task->session_ = session; + task->graph_id_ = graphId; + ready_tasks_.push(task); + task_cond_var_.notify_all(); + build_cond_var_.wait(lock); +} + +void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, + const std::vector &inputs, VectorRef *outputs) { + auto task = std::make_shared(); + task->session_ = session; + task->graph_id_ = graph_id; + task->input_tensors_ = inputs; + MS_EXCEPTION_IF_NULL(session); + session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_); + // maintain a copy of output vector + task->outputs_ = *outputs; + + bool ready = IsAllInputsReady(inputs); + if (!ready) { + std::unique_lock lock(pending_task_mutex_); + pending_tasks_.push_back(task); + return; + } + std::unique_lock lock(task_mutex_); + ready_tasks_.push(task); + task_cond_var_.notify_all(); + py::gil_scoped_release release; + run_cond_var_.wait(lock); +} + +void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask) { + std::unique_lock lock(task_mutex_); + auto task = std::make_shared(); + task->session_ = session; + task->op_run_info_ = op_run_info; + task->graph_info_ = graph_info; + task->input_tensors_ = input_tensors; + task->tensors_mask_ = tensors_mask; + ready_tasks_.push(task); + task_cond_var_.notify_all(); + build_op_cond_var_.wait(lock); +} + +py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors) { + std::unique_lock lock(task_mutex_); + auto task = std::make_shared(); + task->session_ = session; + task->op_run_info_ = op_run_info; + task->graph_info_ = graph_info; + task->input_tensors_ = input_tensors; + ready_tasks_.push(task); + task_cond_var_.notify_all(); + run_op_cond_var_.wait(lock); + + // Trans output to tuple + auto output_tensors = TransformBaseRefListToTuple(task->outputs_); + if (!utils::isa(output_tensors) || + !py::isinstance(utils::cast(output_tensors).object_)) { + MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !"; + } + py::object tuple_obj = utils::cast(output_tensors).object_; + py::tuple tuple_tensors = py::cast(tuple_obj); + return tuple_tensors; +} + +void Executor::StopWorker() { + std::unique_lock lock(task_mutex_); + auto task = std::make_shared(); + ready_tasks_.push(task); + task_cond_var_.notify_all(); +} + +void Executor::OnWorkerExit() { + if (device_name_ == kAscendDevice) { + device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_); + } +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h new file mode 100644 index 0000000000000000000000000000000000000000..3078c41f1d1a492217afde2accc13aabc3bc0d0e --- /dev/null +++ b/mindspore/ccsrc/backend/session/executor.h @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/session/session_basic.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "utils/any.h" +#include "utils/contract.h" + +namespace mindspore { +namespace session { +enum TaskType { kUnKnown, kExit, kCompileNodes, kCompileGraph, kBuildGraph, kBuildOp, kRunGraph, kRunOp }; + +class Task { + public: + Task() = default; + virtual ~Task() = default; + SessionPtr session_{nullptr}; + TaskType type_{kUnKnown}; + virtual void Run() {} +}; + +class CompileNodesTask : public Task { + public: + CompileNodesTask() { type_ = kCompileNodes; } + ~CompileNodesTask() override = default; + void Run() override; + AnfNodePtrList nodes_; + AnfNodePtrList output_nodes_; + GraphId graph_id_{0}; +}; + +class CompileGraphTask : public Task { + public: + CompileGraphTask() { type_ = kCompileGraph; } + ~CompileGraphTask() override = default; + void Run() override; + FuncGraphPtr func_graph_{nullptr}; + GraphId graph_id_{0}; +}; + +class BuildGraphTask : public Task { + public: + BuildGraphTask() { type_ = kBuildGraph; } + ~BuildGraphTask() override = default; + void Run() override; + GraphId graph_id_{0}; +}; + +class RunGraphTask : public Task { + public: + RunGraphTask() { type_ = kRunGraph; } + ~RunGraphTask() override = default; + void Run() override; + std::vector input_tensors_; + VectorRef outputs_; + GraphId graph_id_{0}; + std::map tensor_to_node_; +}; + +class BuildOpTask : public Task { + public: + BuildOpTask() { type_ = kBuildOp; } + ~BuildOpTask() override = default; + void Run() override; + OpRunInfo *op_run_info_{nullptr}; + GraphInfo graph_info_; + std::vector input_tensors_; + std::vector tensors_mask_; +}; + +class RunOpTask : public Task { + public: + RunOpTask() { type_ = kRunOp; } + ~RunOpTask() override = default; + void Run() override; + OpRunInfo *op_run_info_{nullptr}; + GraphInfo graph_info_; + std::vector input_tensors_; + VectorRef outputs_; +}; + +class ExitTask : public Task { + public: + ExitTask() { type_ = kExit; } + ~ExitTask() override = default; +}; + +class Executor { + public: + Executor(const std::string &device_name, uint32_t device_id); + ~Executor() = default; + void WorkerLoop(); + void WorkerJoin(); + GraphId CompileGraphAsync(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs); + GraphId CompileGraphAsync(const SessionPtr &session, NotNull func_graph); + void BuildGraphAsync(const SessionPtr &session, GraphId graphId); + void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector &inputs, + VectorRef *outputs); + void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask); + py::tuple RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors); + void OnRunGraphFinished(); + + protected: + void UpdateOutputTensors(VectorRef *outputs, + const std::map &tensor_to_node); + std::vector> GetNewReadyTasks(); + bool IsAllInputsReady(const std::vector &inputs); + void StopWorker(); + void OnWorkerExit(); + + uint32_t device_id_; + std::string device_name_; + std::mutex task_mutex_; + std::mutex pending_task_mutex_; + std::condition_variable task_cond_var_; + std::condition_variable compile_cond_var_; + std::condition_variable build_cond_var_; + std::condition_variable run_cond_var_; + std::condition_variable build_op_cond_var_; + std::condition_variable run_op_cond_var_; + std::queue> ready_tasks_; + std::list> pending_tasks_; + std::shared_ptr worker_; +}; +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H diff --git a/mindspore/ccsrc/backend/session/executor_manager.cc b/mindspore/ccsrc/backend/session/executor_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..3758adcf2e5785473df3a3c8226d0f7bbcd2b4b8 --- /dev/null +++ b/mindspore/ccsrc/backend/session/executor_manager.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/executor_manager.h" + +namespace mindspore { +namespace session { +std::shared_ptr ExecutorManager::GetExecutor(const std::string &device_name, int device_id) { + std::string device_key = device_name + "_" + std::to_string(device_id); + auto iter = executors_.find(device_key); + if (iter != executors_.end()) { + return iter->second; + } + auto executor = std::make_shared(device_name, device_id); + executors_[device_key] = executor; + return executor; +} + +void ExecutorManager::OnRunGraphFinished() { + for (auto &item : executors_) { + auto &executor = item.second; + if (executor != nullptr) { + executor->OnRunGraphFinished(); + } + } +} + +void ExecutorManager::JoinExecutorWorkers() { + for (auto &item : executors_) { + auto &executor = item.second; + if (executor != nullptr) { + executor->WorkerJoin(); + } + } +} + +void ExecutorManager::Clear() { + JoinExecutorWorkers(); + executors_.clear(); +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/executor_manager.h b/mindspore/ccsrc/backend/session/executor_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..3dc4f6ee17c4c186beb74d75f2f436c86e1a07a6 --- /dev/null +++ b/mindspore/ccsrc/backend/session/executor_manager.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANGER_H_ +#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANGER_H_ +#include +#include +#include +#include +#include "backend/session/executor.h" +namespace mindspore { +namespace session { +class Executor; +class ExecutorManager { + public: + static ExecutorManager &Instance() { + static ExecutorManager instance; + return instance; + } + std::shared_ptr GetExecutor(const std::string &device_name, int device_id); + void OnRunGraphFinished(); + void Clear(); + + private: + ExecutorManager() = default; + ~ExecutorManager() = default; + DISABLE_COPY_AND_ASSIGN(ExecutorManager) + void JoinExecutorWorkers(); + std::map> executors_; +}; +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANGER_H_ diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 037780a8c807144dadda54610898af18396a2e71..23c29019e5def6d2af6e29673a8625ae7337fb52 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -260,16 +260,10 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector &input_tensors) { +void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, VectorRef *outputs) { auto kernel_graph = run_op_graphs_[graph_info]; MS_EXCEPTION_IF_NULL(kernel_graph); // Remove NopOp from execution graph @@ -307,12 +301,8 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get()); // Execute the computation LoadInputData(kernel_graph, input_tensors); - { - py::gil_scoped_release gil_release; - Execute(kernel_graph); - } + Execute(kernel_graph); // Fetch outputs - VectorRef outputs; if (op_run_info.value != nullptr) { std::vector pre_output_tensors; TensorValueToTensor(op_run_info.value, &pre_output_tensors); @@ -320,21 +310,12 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph tensor::TensorPtr tensor = std::make_shared(pre_output->data_type(), pre_output->shape()); tensor->set_device_address(pre_output->device_address()); tensor->set_dirty(false); - outputs.emplace_back(tensor); + outputs->emplace_back(tensor); } } else { - UpdateOutputs(kernel_graph, &outputs, input_tensors); - } - // Trans output to tuple - auto output_tensors = TransformBaseRefListToTuple(outputs); - if (!utils::isa(output_tensors) || - !py::isinstance(utils::cast(output_tensors).object_)) { - MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !"; + UpdateOutputs(kernel_graph, outputs, input_tensors); } - py::object tuple_obj = utils::cast(output_tensors).object_; - py::tuple tuple_tensors = py::cast(tuple_obj); RunOpClearMemory(kernel_graph.get()); - return tuple_tensors; } #ifdef ENABLE_DEBUGGER diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index 70d904ef7aebc8853ef173d5d9ada581e3911372..f79ae4e8d56b0fe9c8b26a0d1e7d374149663123 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -18,6 +18,7 @@ #include #include +#include #include "backend/session/session_basic.h" #include "backend/session/kernel_graph.h" #include "backend/session/session_factory.h" @@ -31,18 +32,15 @@ class GPUSession : public SessionBasic { GPUSession() = default; ~GPUSession() override = default; - void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kGPUDevice, device_id); - } + void Init(uint32_t device_id) override { InitDevice(kGPUDevice, device_id); } GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, const std::vector &tensors_mask) override; - py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) override; + void RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, VectorRef *outputs) override; private: void SelectKernel(const std::shared_ptr &kernel_graph) const; diff --git a/mindspore/ccsrc/backend/session/infer_session.cc b/mindspore/ccsrc/backend/session/infer_session.cc index 1cff6a3b7c31881343b4834e40906f1f3b61ef86..e9985d8a4d1a2796e272860140a498b07c0e5188 100644 --- a/mindspore/ccsrc/backend/session/infer_session.cc +++ b/mindspore/ccsrc/backend/session/infer_session.cc @@ -318,7 +318,7 @@ void MSInferSession::RegAllOp() { Status MSInferSession::CompileGraph(std::shared_ptr funcGraphPtr, uint32_t &model_id) { MS_ASSERT(session_impl_ != nullptr); try { - auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); + auto graph_id = session_impl_->CompileGraphAsync(NOT_NULL(funcGraphPtr)); py::gil_scoped_release gil_release; model_id = graph_id; return SUCCESS; @@ -332,7 +332,7 @@ std::vector MSInferSession::RunGraph(uint32_t graph_id, const std::vector &inputs) { try { VectorRef outputs; - session_impl_->RunGraph(graph_id, inputs, &outputs); + session_impl_->RunGraphAsync(graph_id, inputs, &outputs); return TransformVectorRefToMultiTensor(outputs); } catch (std::exception &e) { @@ -364,16 +364,16 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) { return FAILED; } ms_context->set_device_target(device); + if (!context::OpenTsd(ms_context)) { + MS_LOG(ERROR) << "Session init OpenTsd failed!"; + return FAILED; + } session_impl_ = session::SessionFactory::Get().Create(ajust_device); if (session_impl_ == nullptr) { MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; return FAILED; } session_impl_->Init(device_id); - if (!context::OpenTsd(ms_context)) { - MS_LOG(ERROR) << "Session init OpenTsd failed!"; - return FAILED; - } return SUCCESS; } diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 5e9863516c8310cbd14fd0c60a1bc9e43abaa474..8a195f32db3b4904da65d192d1b374fa6cc0ca8c 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -25,8 +25,11 @@ #include "common/trans.h" #include "utils/config_manager.h" #include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/executor.h" +#include "backend/session/executor_manager.h" #include "backend/optimizer/common/common_backend_optimization.h" #include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_runtime_manager.h" #include "utils/ms_utils.h" #include "ir/dtype.h" #include "ir/anf.h" @@ -55,8 +58,10 @@ ValuePtr GetParamDefaultValue(const AnfNodePtr &node) { return parameter->default_param(); } -tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph, - const DeviceAddressPtr &address) { +tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair, + const KernelGraphPtr &graph) { + auto &node = node_output_pair.first; + auto &output_index = node_output_pair.second; MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(graph); TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); @@ -68,9 +73,9 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index if (graph->IsUniqueTargetInternalOutput(node, output_index)) { temp_shape.emplace_back(1); tensor = std::make_shared(type_id, temp_shape); - tensor->set_device_address(address); tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); tensor->set_dirty(false); + tensor->SetNeedWait(true); return tensor; } @@ -88,23 +93,25 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index // if in paynative mode,data only copyed to host when user want to print data auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - MS_EXCEPTION_IF_NULL(address); - if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { - tensor->set_device_address(address); - tensor->set_dirty(false); - } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), - LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { - MS_LOG(INFO) << "Output sync device to host error!!!"; - tensor->set_dirty(false); + if (ms_context->execution_mode() != kPynativeMode && ms_context->device_target() != kGPUDevice) { + tensor->set_need_sync(true); + } + if (ms_context->execution_mode() != kPynativeMode) { + tensor->SetNeedWait(true); } + tensor->set_dirty(false); return tensor; } -BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph, - const std::vector &input_tensors) { +BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph, + const std::vector &input_tensors, + std::map *tensor_to_node) { + auto &node = node_output_pair.first; + auto &output_index = node_output_pair.second; MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; + MS_EXCEPTION_IF_NULL(tensor_to_node); + MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << node_output_pair.second << "]"; // if node is a value node, no need sync addr from device to host if (node->isa()) { auto value_node = node->cast(); @@ -124,13 +131,16 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr"; } } - auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); - return CreateOutputTensor(node, output_index, graph, address); + auto tensor = CreateCNodeOutputTensor(node_output_pair, graph); + (*tensor_to_node)[tensor] = node_output_pair; + return tensor; } -BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph, - const std::vector &input_tensors) { +BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph, + const std::vector &input_tensors, + std::map *tensor_to_node) { MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(tensor_to_node); MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]"; auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); MS_EXCEPTION_IF_NULL(item_with_index.first); @@ -141,7 +151,7 @@ BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph MS_EXCEPTION_IF_NULL(cnode); VectorRef ret; for (size_t i = 1; i < cnode->inputs().size(); ++i) { - auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors); + auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node); ret.push_back(out); } return ret; @@ -151,7 +161,7 @@ BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph if (size == 0) { return VectorRef(); } - return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors); + return CreateNodeOutputTensor(item_with_index, graph, input_tensors, tensor_to_node); } ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { @@ -321,6 +331,12 @@ bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) { GraphId SessionBasic::graph_sum_ = 0; +void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id) { + device_id_ = device_id; + context_ = std::make_shared(device_name, device_id); + executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id); +} + KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const { auto it = graphs_.find(graph_id); if (it == graphs_.end()) { @@ -982,11 +998,36 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr &kernel_grap const std::vector &input_tensors) const { MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(outputs); + std::map tensor_to_node; auto anf_outputs = kernel_graph->outputs(); for (auto &item : anf_outputs) { MS_EXCEPTION_IF_NULL(item); MS_LOG(INFO) << "Update output[" << item->DebugString() << "]"; - outputs->emplace_back(CreateTensorForOutput(item, kernel_graph, input_tensors)); + outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, &tensor_to_node)); + } + + for (auto &item : tensor_to_node) { + auto &tensor = item.first; + auto &node = item.second.first; + auto &output_index = item.second.second; + auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); + tensor->set_device_address(address); + tensor->SetNeedWait(false); + } +} + +void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector &input_tensors, + VectorRef *outputs, + std::map *tensor_to_node) { + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(outputs); + MS_EXCEPTION_IF_NULL(tensor_to_node); + auto anf_outputs = kernel_graph->outputs(); + for (auto &item : anf_outputs) { + MS_EXCEPTION_IF_NULL(item); + MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]"; + outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node)); } } @@ -1231,32 +1272,6 @@ std::shared_ptr SessionBasic::ConstructSingleOpGraph(const OpRunInf return graph; } -BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) { - if (utils::isa(base_ref)) { - auto ref_list = utils::cast(base_ref); - py::tuple output_tensors(ref_list.size()); - for (size_t i = 0; i < ref_list.size(); ++i) { - auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef - if (utils::isa(output)) { - auto tensor_ptr = utils::cast(output); - MS_EXCEPTION_IF_NULL(tensor_ptr); - output_tensors[i] = tensor_ptr; - } else if (utils::isa(output)) { - py::object obj = utils::cast(output).object_; - py::tuple tensor_tuple = py::cast(obj); - output_tensors[i] = tensor_tuple; - } else { - MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; - } - } - return output_tensors; // turn tuple to py::object and store in PyObjectRef - } else if (utils::isa(base_ref)) { - return base_ref; - } else { - MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; - } -} - KernelGraphPtr SessionBasic::NewKernelGraph() { auto graph = std::make_shared(); graph->set_graph_id(graph_sum_); @@ -1281,6 +1296,40 @@ AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::ve return nullptr; } +GraphId SessionBasic::CompileGraphAsync(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + MS_EXCEPTION_IF_NULL(executor_); + return executor_->CompileGraphAsync(shared_from_this(), lst, outputs); +} + +GraphId SessionBasic::CompileGraphAsync(NotNull func_graph) { + MS_EXCEPTION_IF_NULL(executor_); + return executor_->CompileGraphAsync(shared_from_this(), func_graph); +} + +void SessionBasic::BuildGraphAsync(GraphId graph_id) { + MS_EXCEPTION_IF_NULL(executor_); + executor_->BuildGraphAsync(shared_from_this(), graph_id); +} + +void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, + const std::vector &tensors_mask) { + MS_EXCEPTION_IF_NULL(executor_); + executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask); +} + +py::tuple SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(executor_); + return executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors); +} + +void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector &inputs, + VectorRef *outputs) { + MS_EXCEPTION_IF_NULL(executor_); + executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs); +} + #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { if (!parallel::ps::Util::IsRoleOfWorker()) { diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index bcdc89eb8fec1c4ef6c9d08c7a8fe67383e3058f..1e963722be040f2cd838f350ee88e3843699e808 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -22,10 +22,10 @@ #include #include #include - #include "utils/base_ref_extends.h" #include "backend/session/session_context.h" #include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" #include "ir/anf.h" #include "ir/tensor.h" #include "utils/any.h" @@ -49,8 +49,8 @@ using AnyListPtr = std::shared_ptr; using OpRunInfo = pynative::OpExecInfo; using OpRunInfoPtr = std::shared_ptr; - -class SessionBasic { +class Executor; +class SessionBasic : public std::enable_shared_from_this { public: SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) { #ifdef ENABLE_DEBUGGER @@ -60,6 +60,12 @@ class SessionBasic { virtual void Init(uint32_t device_id) { device_id_ = device_id; } + void InitDevice(const std::string &device_name, uint32_t device_id); + + virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector &input_tensors, + VectorRef *outputs, + std::map *tensor_to_node); + virtual ~SessionBasic() { summary_callback_ = nullptr; } virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; @@ -69,12 +75,19 @@ class SessionBasic { virtual void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) = 0; - virtual void BuildOp(const OpRunInfo &, const GraphInfo &, const std::vector &input_tensors, - const std::vector &tensors_mask) {} + virtual void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask) {} - virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector &input_tensors) { - return py::tuple(); - } + virtual void RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, VectorRef *outputs) {} + + GraphId CompileGraphAsync(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); + GraphId CompileGraphAsync(NotNull func_graph); + void BuildGraphAsync(GraphId graphId); + void RunGraphAsync(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs); + void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector &input_tensors, + const std::vector &tensors_mask); + py::tuple RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector &input_tensors); virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); @@ -116,9 +129,11 @@ class SessionBasic { void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs); protected: - virtual void SetSummaryNodes(KernelGraph *graph); // Get graph by graph id ,if not exist return null ptr KernelGraphPtr GetGraph(GraphId graph_id) const; + + virtual void SetSummaryNodes(KernelGraph *graph); + virtual void LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const; void UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs, @@ -132,8 +147,6 @@ class SessionBasic { std::shared_ptr ConstructSingleOpGraph(const OpRunInfo &op_run_info, const std::vector &input_tensors, const std::vector &tensors_mask); - // trans BaseRef list to py::tuple - BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); // create a new kernel graph and update the graph sum KernelGraphPtr NewKernelGraph(); std::vector CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph); @@ -152,6 +165,7 @@ class SessionBasic { CallBackFunc summary_callback_; static GraphId graph_sum_; uint32_t device_id_; + std::shared_ptr executor_; #ifdef ENABLE_DEBUGGER std::shared_ptr debugger_; #endif diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index afd59d29aed5a477c440d829c3f09b93fc834342..5429821b4f468ae4f560c69efc609466b83575dd 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -37,6 +37,7 @@ #include "frontend/parallel/context.h" #include "frontend/parallel/graph_util/get_parallel_info.h" #include "runtime/device/kernel_runtime_manager.h" +#include "backend/session/executor_manager.h" #include "debug/trace.h" #include "pipeline/pynative/pynative_execute.h" #include "frontend/optimizer/py_pass_manager.h" @@ -1023,7 +1024,6 @@ void ClearResAtexit() { MS_LOG(DEBUG) << "Pipeline clear all resource"; pynative::ClearPyNativeSession(); session::ClearPythonParasMap(); - device::KernelRuntimeManager::Instance().ClearRuntimeResource(); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) if (mindspore::parallel::ps::Util::IsParamServerMode()) { if (parallel::ps::Util::IsRoleOfWorker()) { @@ -1047,6 +1047,8 @@ void ClearResAtexit() { #else ConfigManager::GetInstance().ResetIterNum(); #endif + session::ExecutorManager::Instance().Clear(); + device::KernelRuntimeManager::Instance().ClearRuntimeResource(); ReleaseGeTsd(); parse::python_adapter::ResetPythonScope(); } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 0145f4656bc1a85d126f5559d808f2f26178226e..ca314753697e329b8acb1aae110f7b6cf6bc9088 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -574,9 +574,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); // get graph info for checking it whether existing in the cache std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); - session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask); + session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask); EraseValueNodeTensor(tensors_mask, &input_tensors); - py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors); + py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors); ms_context->set_enable_pynative_infer(false); *status = PYNATIVE_SUCCESS; MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms"; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 4734f482bcf054b11103530934925c865e6f97cf..7f2bde2e082f0f9af58e932101465bdb8171783d 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -94,7 +94,18 @@ std::string GetRankId() { AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } +void AscendKernelRuntime::SetContext() { + if (rt_context_ == nullptr) { + return; + } + auto ret = rtCtxSetCurrent(rt_context_); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; + } +} + void AscendKernelRuntime::ClearGraphModelMap() { + SetContext(); for (auto &iter : graph_data_dumper_) { MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first; auto &data_dumper = iter.second; @@ -118,6 +129,7 @@ void AscendKernelRuntime::ClearGraphModelMap() { void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &, const std::unordered_set &, const std::vector &) { + SetContext(); MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper"; if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) { MS_LOG(DEBUG) << "Unload dump info " << graph_id; @@ -156,6 +168,10 @@ bool AscendKernelRuntime::NeedDestroyHccl() { void AscendKernelRuntime::ReleaseDeviceRes() { MS_LOG(INFO) << "Ascend finalize start"; + if (!initialized_) { + return; + } + SetContext(); // release ge runtime ClearGraphModelMap(); @@ -438,6 +454,7 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size } bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { + SetContext(); if (graph == nullptr) { MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; } @@ -494,6 +511,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { } bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { + SetContext(); if (graph == nullptr) { MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; } @@ -586,6 +604,7 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) { } bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { + SetContext(); MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id(); @@ -613,6 +632,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { } bool AscendKernelRuntime::SyncStream() { + SetContext(); if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { // o for switch stream MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; return false; @@ -649,12 +669,7 @@ bool AscendKernelRuntime::InitDevice() { if (ret != RT_ERROR_NONE) { MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast(ret) << "]"; } - - ret = rtCtxSetCurrent(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; - } - + SetContext(); ret = rtStreamCreate(&stream_, 0); if (ret != RT_ERROR_NONE) { MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]"; @@ -664,14 +679,9 @@ bool AscendKernelRuntime::InitDevice() { } bool AscendKernelRuntime::ResetDevice() { - auto ret = rtCtxSetCurrent(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Call rtCtxSetCurrent failed"; - return false; - } - + SetContext(); if (stream_ != nullptr) { - ret = rtStreamDestroy(stream_); + auto ret = rtStreamDestroy(stream_); if (ret != RT_ERROR_NONE) { MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]"; } @@ -679,7 +689,7 @@ bool AscendKernelRuntime::ResetDevice() { } if (rt_context_ != nullptr) { - ret = rtCtxDestroy(rt_context_); + auto ret = rtCtxDestroy(rt_context_); if (ret != RT_ERROR_NONE) { MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]"; } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index 42b9dcda990658cb8f30b0286edc21a88c55da16..6b7ccc085cdcdaed6f8a2a498e63038526639dab 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -60,6 +60,7 @@ class AscendKernelRuntime : public KernelRuntime { bool HcclInit(); bool NeedDestroyHccl(); bool DestroyHccl(); + void SetContext(); void ClearGraphModelMap(); void ReleaseDeviceRes() override; diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index 2ce7f70cb364dcf1cf25e135bab5641c47a589aa..07d62b32786f41a83354830af75ce9bf2770a877 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "backend/kernel_compiler/kernel.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_context.h" @@ -137,10 +138,8 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t } tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, - size_t index, - std::vector *need_sync_outputs) { + size_t index) { MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(need_sync_outputs); size_t output_size = AnfAlgo::GetOutputTensorNum(node); if (index >= output_size) { MS_LOG(EXCEPTION) << "Invalid input index " << index; @@ -163,16 +162,15 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k } if (bound_addresses_.find(address) != bound_addresses_.end()) { tensor->set_device_address(address); - need_sync_outputs->emplace_back(tensor); + tensor->set_need_sync(true); } else { if (infer_type_id != device_type_id) { size_t type_size = GetTypeByte(TypeIdToType(device_type_id)); ShapeVector data_shape = tensor->shape(); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); address->ptr_ = resource_manager_.MemMalloc(tensor_size); - need_sync_outputs->emplace_back(tensor); tensor->set_device_address(address); - need_sync_outputs->emplace_back(tensor); + tensor->set_need_sync(true); } else { tensor->set_device_address(nullptr); address->ptr_ = tensor->data_c(); @@ -185,8 +183,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k } BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph, - const session::KernelWithIndex &kernel_with_index, - std::vector *need_sync_outputs) { + const session::KernelWithIndex &kernel_with_index) { auto &input_node = kernel_with_index.first; auto index = kernel_with_index.second; MS_EXCEPTION_IF_NULL(input_node); @@ -197,12 +194,12 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_grap VectorRef ret; for (size_t i = 1; i < node->inputs().size(); i++) { auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0); - auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs); + auto out = CreatTensorForOutput(kernel_graph, item_with_index); ret.push_back(out); } return ret; } - return CreatTensorForOutput(kernel_graph, node, index, need_sync_outputs); + return CreatTensorForOutput(kernel_graph, node, index); } else if (input_node->isa()) { auto iter = input_param_tensor_map_.find(input_node); if (iter != input_param_tensor_map_.end()) { @@ -216,7 +213,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_grap return BaseRef(); } void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const std::vector &inputs, - VectorRef *outputs, std::vector *need_sync_outputs) { + VectorRef *outputs) { MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(outputs); // bind input ptr @@ -262,7 +259,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const auto output_nodes = kernel_graph->outputs(); for (const auto &item : output_nodes) { auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); - auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs); + auto out = CreatTensorForOutput(kernel_graph, item_with_index); outputs->push_back(std::move(out)); } } diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h index e391332f85ea22e9b293cf5dab0df939142945e6..ff448f15695beceb7a6da113af9e57714768c77c 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h @@ -39,7 +39,7 @@ class CPUKernelRuntime : public KernelRuntime { bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override; void AssignKernelAddress(session::KernelGraph *kernel_graph); void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector &inputs, - VectorRef *outputs, std::vector *need_sync_outputs); + VectorRef *outputs); void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); @@ -49,11 +49,9 @@ class CPUKernelRuntime : public KernelRuntime { TypeId type_id) override; private: - tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index, - std::vector *need_sync_outputs); + tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index); - BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index, - std::vector *need_sync_outputs); + BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index); void AssignValueNodeAddress(session::KernelGraph *kernel_graph); void AssignInputNodeAddress(const session::KernelGraph *kernel_graph); void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc index 0c7c66e3c83b9362688fa9705a0b5aea03a30048..c0395b3a434017d3d1dde34501ad9011440075f6 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc @@ -49,8 +49,13 @@ void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntim } } +std::string KernelRuntimeManager::GetDeviceKey(const std::string &device_name, uint32_t device_id) { + std::string device_key = device_name + "_" + std::to_string(device_id); + return device_key; +} + KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) { - std::string runtime_key = device_name + "_" + std::to_string(device_id); + auto runtime_key = GetDeviceKey(device_name, device_id); auto runtime_iter = runtime_map_.find(runtime_key); if (runtime_iter != runtime_map_.end()) { return runtime_iter->second.get(); @@ -72,8 +77,8 @@ KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &d } KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) { + std::string runtime_key = GetDeviceKey(device_name, device_id); std::lock_guard guard(lock_); - std::string runtime_key = device_name + "_" + std::to_string(device_id); auto runtime_iter = runtime_map_.find(runtime_key); if (runtime_iter != runtime_map_.end()) { return runtime_iter->second.get(); @@ -92,5 +97,20 @@ KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_ return kernel_runtime.get(); } + +void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) { + std::string runtime_key = GetDeviceKey(device_name, device_id); + std::lock_guard guard(lock_); + auto runtime_iter = runtime_map_.find(runtime_key); + if (runtime_iter == runtime_map_.end()) { + return; + } + auto runtime = runtime_iter->second.get(); + if (runtime == nullptr) { + return; + } + runtime->ReleaseDeviceRes(); + runtime_map_.erase(runtime_iter); +} } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h index 26e0ff8804186761be102fbd823079f0028f26be..31dec77eefaab44311442d7e3a51c527b623c79c 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h @@ -39,6 +39,7 @@ class KernelRuntimeManager { void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator); KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id); KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id); + void ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id); void ClearRuntimeResource(); void ClearGraphResource(uint32_t graph_id, const std::vector &inputs, const std::unordered_set &value_nodes, @@ -48,6 +49,7 @@ class KernelRuntimeManager { KernelRuntimeManager() = default; ~KernelRuntimeManager() = default; DISABLE_COPY_AND_ASSIGN(KernelRuntimeManager); + std::string GetDeviceKey(const std::string &device_name, uint32_t device_id); std::map > runtime_map_; std::map runtime_creators_; std::mutex lock_; diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 8c587a6227b75cca9f89e523f2e1b55a078586b1..52200d1ecc8ba0e5c602532829885fed9a01e82c 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -54,9 +54,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri GraphId graph_id = kInvalidGraphId; if (target != target_device_ && !target.empty()) { CreateOtherSession(target); - graph_id = other_sess_->CompileGraph(lst, outputs); + graph_id = other_sess_->CompileGraphAsync(lst, outputs); } else { - graph_id = target_sess_->CompileGraph(lst, outputs); + graph_id = target_sess_->CompileGraphAsync(lst, outputs); } if (MsContext::GetInstance()->precompile_only()) { @@ -64,9 +64,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri return result; } if (target != target_device_ && !target.empty()) { - other_sess_->BuildGraph(graph_id); + other_sess_->BuildGraphAsync(graph_id); } else if (!is_multi_graph_sink_) { - target_sess_->BuildGraph(graph_id); + target_sess_->BuildGraphAsync(graph_id); } result.run = std::make_shared( [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); }); @@ -137,9 +137,9 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s VectorRef outputs; // call ms rungraph (graphId, input ,output) if (target != target_device_ && !target.empty()) { - other_sess_->RunGraph(g, inputs, &outputs); + other_sess_->RunGraphAsync(g, inputs, &outputs); } else { - target_sess_->RunGraph(g, inputs, &outputs); + target_sess_->RunGraphAsync(g, inputs, &outputs); } MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size(); @@ -150,7 +150,7 @@ void MsBackend::Link(GraphId graph_id) { if (graph_id == kInvalidGraphId) { graph_id = target_sess_->GetFinalRunGraph(); } - target_sess_->BuildGraph(graph_id); + target_sess_->BuildGraphAsync(graph_id); } Backend::Backend(const std::string &name) : name_(name) { @@ -186,7 +186,7 @@ void MsBackend::CreateOtherSession(const std::string &target) { other_device_ = target; } -GraphId MsBackend::CompileGraph(NotNull fg) { return target_sess_->CompileGraph(fg); } +GraphId MsBackend::CompileGraph(NotNull fg) { return target_sess_->CompileGraphAsync(fg); } VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index ac121fa9f4ec843d6811d91f12d3521badf22ea1..5d9e47194a83b5524084e6e2f0adcf7e1d774eb9 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -424,6 +424,8 @@ Tensor::Tensor(const Tensor &tensor) data_(tensor.data_), dirty_(tensor.dirty_), id_(tensor.id_), + event_(tensor.event_), + need_sync_(tensor.need_sync_), device_sync_(tensor.device_sync_), padding_type_(tensor.padding_type()) {} @@ -433,6 +435,8 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type) data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)), dirty_(tensor.dirty_), id_(tensor.id_), + event_(tensor.event_), + need_sync_(tensor.need_sync_), device_sync_(tensor.device_sync_), padding_type_(tensor.padding_type()) {} @@ -483,6 +487,8 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) { device_sync_ = tensor.device_sync_; data_ = tensor.data_; id_ = tensor.id_; + event_ = tensor.event_; + need_sync_ = tensor.need_sync_; padding_type_ = tensor.padding_type_; } return *this; @@ -547,6 +553,7 @@ std::string Tensor::ToStringRepr() const { } void Tensor::data_sync() const { + const_cast(this)->Wait(); if (device_sync_ != nullptr) { if (!device_sync_->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 13ef1cbab21a1ed96037236d05065a1b19070235..1228131c1975f4f53415b1ffffd848cef6f62328 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include "ir/device_sync.h" #include "ir/meta_tensor.h" @@ -73,6 +75,30 @@ class TensorData { using TensorDataPtr = std::shared_ptr; +struct WaitEvent { + bool need_wait_{false}; + std::mutex mutex_; + std::condition_variable cond_var_; + + void Wait() { + std::unique_lock lock(mutex_); + if (!need_wait_) { + return; + } + cond_var_.wait(lock, [this] { return !need_wait_; }); + } + + void set_need_wait(bool need_wait) { + std::unique_lock lock(mutex_); + need_wait_ = need_wait; + if (!need_wait_) { + cond_var_.notify_all(); + } + } + + bool need_wait() const { return need_wait_; } +}; + // Tensor entity class class Tensor : public MetaTensor { public: @@ -244,11 +270,40 @@ class Tensor : public MetaTensor { std::string id() const { return id_; } + void SetNeedWait(bool need_wait) { + if (event_ != nullptr) { + event_->set_need_wait(need_wait); + } else if (need_wait) { + event_ = std::make_shared(); + event_->set_need_wait(need_wait); + } + } + + bool NeedWait() const { + if (event_ != nullptr) { + return event_->need_wait(); + } + return false; + } + + void Wait() { + if (event_ != nullptr) { + event_->Wait(); + } + event_ == nullptr; + } + + void set_need_sync(bool need_sync) { need_sync_ = need_sync; } + + bool need_sync() const { return need_sync_; } + private: bool init_flag_{false}; TensorDataPtr data_{nullptr}; bool dirty_{true}; std::string id_{""}; + std::shared_ptr event_{nullptr}; + bool need_sync_{false}; DeviceSyncPtr device_sync_{nullptr}; std::vector padding_type_; }; diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index c822a8c219fd001035574fdbd9cab3dddc594009..c677abded748152d12e81e23b8b813c02d1f309e 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -84,6 +84,8 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/backend/session/ascend_control_parser.cc" "../../../mindspore/ccsrc/backend/session/kernel_graph.cc" "../../../mindspore/ccsrc/backend/session/session_basic.cc" + "../../../mindspore/ccsrc/backend/session/executor.cc" + "../../../mindspore/ccsrc/backend/session/executor_manager.cc" "../../../mindspore/ccsrc/backend/session/session_factory.cc" "../../../mindspore/ccsrc/backend/session/kernel_build_client.cc" "../../../mindspore/ccsrc/vm/*.cc"