Commit 1821e98e authored by mindspore-ci-bot, committed by Gitee

!4087 async run graph

Merge pull request !4087 from kisnwang/async-run-graph
......@@ -3,6 +3,8 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"kernel_graph.cc"
"session_basic.cc"
"session_factory.cc"
"executor.cc"
"executor_manager.cc"
"anf_runtime_algorithm.cc"
)
......
......@@ -312,7 +312,6 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
// if there is no child graph and no anf output exists
if (!kernel_graph->executable()) {
MS_LOG(INFO) << "No child graph has anf output";
UpdateOutputs(kernel_graph, outputs, inputs);
return;
}
// load input data from user input
......@@ -322,12 +321,9 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
InitPSParamAndOptim(kernel_graph, inputs);
#endif
{
py::gil_scoped_release release;
// run task on device
ExecTask(kernel_graph);
}
// get result from device
UpdateOutputs(kernel_graph, outputs, inputs);
// summary
Summary(kernel_graph.get());
#ifdef ENABLE_DEBUGGER
......@@ -396,8 +392,8 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph
MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
}
py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
auto graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
......@@ -408,7 +404,6 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
// run op
RunOpExecTask(graph);
// get output
VectorRef outputs;
if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
......@@ -416,22 +411,13 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
outputs.emplace_back(tensor);
outputs->emplace_back(tensor);
}
} else {
UpdateOutputs(graph, &outputs, input_tensors);
UpdateOutputs(graph, outputs, input_tensors);
}
// trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(outputs);
if (!utils::isa<PyObjectRef>(output_tensors) ||
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
MS_LOG(EXCEPTION) << "The output tensors should be a tuple !";
}
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
RunOpMemoryClear(graph.get());
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
return tuple_tensors;
}
// compile graph steps
......
......@@ -29,7 +29,7 @@
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/session_factory.h"
#include "backend/session/ascend_control_parser.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "runtime/context.h"
namespace mindspore {
namespace session {
......@@ -38,10 +38,17 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2,
class AscendSession : public SessionBasic {
public:
AscendSession() { final_graph_id_ = kInvalidGraphId; }
~AscendSession() override { mindspore::device::ascend::AscendMemoryPool::GetInstance().ResetIdleMemBuf(); }
~AscendSession() override = default;
void Init(uint32_t device_id) override {
SessionBasic::Init(device_id);
context_ = std::make_shared<Context>(kAscendDevice, device_id);
InitDevice(kAscendDevice, device_id);
auto ret = rtCtxCreate(&rt_context_, 0, device_id);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
}
ret = rtCtxSetCurrent(rt_context_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
}
}
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
......@@ -49,8 +56,8 @@ class AscendSession : public SessionBasic {
void BuildGraph(GraphId) override;
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override;
py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) override;
void RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
// get graph id in child graphs by ME front anf node pointer
GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
......@@ -121,6 +128,8 @@ class AscendSession : public SessionBasic {
std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
// final_graph_id is used in every root graph has it's own session situation
GraphId final_graph_id_;
// ascend runtime context
rtContext_t rt_context_{nullptr};
};
MS_REG_SESSION(kAscendDevice, AscendSession);
} // namespace session
......
......@@ -83,15 +83,23 @@ GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
return graph_id;
}
void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_LOG(INFO) << "Bind input output address";
runtime_.BindInputOutput(kernel_graph.get(), input_tensors, outputs);
return;
}
void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
auto &kernel_graph = graphs_[graph_id];
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
InitPSParamAndOptim(kernel_graph, inputs);
#endif
MS_LOG(INFO) << "Bind input output address";
std::vector<tensor::TensorPtr> need_sync_outputs;
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs, &need_sync_outputs);
MS_LOG(INFO) << "Run graph start";
auto execution_order = kernel_graph->execution_order();
Reorder(&execution_order);
......@@ -114,9 +122,6 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
if (!ret) {
MS_LOG(EXCEPTION) << "Run graph failed";
}
for (auto output : need_sync_outputs) {
(void)output->data_sync();
}
if (enable_summary) {
Summary(kernel_graph.get());
......
......@@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_BACKEND_SESSION_CPU_SESSION_H
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
......@@ -28,13 +29,13 @@ class CPUSession : public SessionBasic {
public:
CPUSession() = default;
~CPUSession() override = default;
void Init(uint32_t device_id) override {
SessionBasic::Init(device_id);
context_ = std::make_shared<Context>(kCPUDevice, device_id);
}
void Init(uint32_t device_id) override { InitDevice(kCPUDevice, device_id); }
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) override;
protected:
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override;
void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/session/executor.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/session/executor_manager.h"
namespace mindspore {
namespace session {
namespace {
void UpdateOutputTensors(VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
MS_EXCEPTION_IF_NULL(outputs);
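// Recurse into nested VectorRefs; for each output tensor, clear its wait flag,
// rebind the device address recorded at dispatch time, and sync to host if flagged.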
for (auto item : *outputs) {
if (utils::isa<VectorRefPtr>(item)) {
auto vector_ref = utils::cast<VectorRef>(item);
UpdateOutputTensors(&vector_ref, tensor_to_node);
} else if (utils::isa<tensor::TensorPtr>(item)) {
auto tensor = utils::cast<tensor::TensorPtr>(item);
MS_EXCEPTION_IF_NULL(tensor);
tensor->SetNeedWait(false);
auto iter = tensor_to_node.find(tensor);
if (iter != tensor_to_node.end()) {
auto &node = iter->second.first;
auto &output_index = iter->second.second;
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
tensor->set_device_address(address);
}
if (tensor->need_sync()) {
tensor->data_sync();
tensor->set_need_sync(false);
}
}
}
}
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
if (utils::isa<VectorRef>(base_ref)) {
auto ref_list = utils::cast<VectorRef>(base_ref);
py::tuple output_tensors(ref_list.size());
for (size_t i = 0; i < ref_list.size(); ++i) {
auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef
if (utils::isa<tensor::TensorPtr>(output)) {
auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
MS_EXCEPTION_IF_NULL(tensor_ptr);
output_tensors[i] = tensor_ptr;
} else if (utils::isa<PyObjectRef>(output)) {
py::object obj = utils::cast<PyObjectRef>(output).object_;
py::tuple tensor_tuple = py::cast<py::tuple>(obj);
output_tensors[i] = tensor_tuple;
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
return output_tensors; // turn tuple to py::object and store in PyObjectRef
} else if (utils::isa<tensor::TensorPtr>(base_ref)) {
return base_ref;
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
} // namespace
void CompileNodesTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
graph_id_ = session_->CompileGraph(nodes_, output_nodes_);
}
void CompileGraphTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
graph_id_ = session_->CompileGraph(NOT_NULL(func_graph_));
}
void BuildGraphTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
session_->BuildGraph(graph_id_);
}
void RunGraphTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
session_->RunGraph(graph_id_, input_tensors_, &outputs_);
UpdateOutputTensors(&outputs_, tensor_to_node_);
ExecutorManager::Instance().OnRunGraphFinished();
}
void BuildOpTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
session_->BuildOp(*op_run_info_, graph_info_, input_tensors_, tensors_mask_);
}
void RunOpTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_);
}
Executor::Executor(const std::string &device_name, uint32_t device_id) {
device_name_ = device_name;
device_id_ = device_id;
worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
}
void Executor::WorkerJoin() {
StopWorker();
worker_->join();
}
void Executor::WorkerLoop() {
while (true) {
std::shared_ptr<Task> task;
{
std::unique_lock<std::mutex> lock(task_mutex_);
task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
task = ready_tasks_.front();
ready_tasks_.pop();
}
if (task->type_ == kExit) {
OnWorkerExit();
return;
}
task->Run();
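// Wake whichever caller is blocked in the corresponding *Async method.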
if (task->type_ == kCompileNodes) {
compile_cond_var_.notify_all();
} else if (task->type_ == kCompileGraph) {
compile_cond_var_.notify_all();
} else if (task->type_ == kBuildGraph) {
build_cond_var_.notify_all();
} else if (task->type_ == kRunGraph) {
run_cond_var_.notify_all();
} else if (task->type_ == kBuildOp) {
build_op_cond_var_.notify_all();
} else if (task->type_ == kRunOp) {
run_op_cond_var_.notify_all();
}
}
}
std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
std::unique_lock<std::mutex> lock(pending_task_mutex_);
for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
auto task = *iter;
if (IsAllInputsReady(task->input_tensors_)) {
new_ready_tasks.emplace_back(task);
pending_tasks_.erase(iter++);
} else {
iter++;
}
}
return new_ready_tasks;
}
void Executor::OnRunGraphFinished() {
auto new_ready_tasks = GetNewReadyTasks();
std::unique_lock<std::mutex> lock(task_mutex_);
for (auto &task : new_ready_tasks) {
ready_tasks_.push(task);
}
if (new_ready_tasks.size() > 0) {
task_cond_var_.notify_all();
}
}
bool Executor::IsAllInputsReady(const std::vector<tensor::TensorPtr> &inputs) {
for (auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
if (input->NeedWait()) {
return false;
}
}
return true;
}
GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrList &lst,
const AnfNodePtrList &outputs) {
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<CompileNodesTask>();
task->session_ = session;
task->nodes_ = lst;
task->output_nodes_ = outputs;
ready_tasks_.push(task);
task_cond_var_.notify_all();
compile_cond_var_.wait(lock);
return task->graph_id_;
}
GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<CompileGraphTask>();
task->session_ = session;
task->func_graph_ = func_graph;
ready_tasks_.push(task);
task_cond_var_.notify_all();
compile_cond_var_.wait(lock);
return task->graph_id_;
}
void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) {
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<BuildGraphTask>();
task->session_ = session;
task->graph_id_ = graphId;
ready_tasks_.push(task);
task_cond_var_.notify_all();
build_cond_var_.wait(lock);
}
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
auto task = std::make_shared<RunGraphTask>();
task->session_ = session;
task->graph_id_ = graph_id;
task->input_tensors_ = inputs;
MS_EXCEPTION_IF_NULL(session);
session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
// maintain a copy of output vector
task->outputs_ = *outputs;
bool ready = IsAllInputsReady(inputs);
if (!ready) {
std::unique_lock<std::mutex> lock(pending_task_mutex_);
pending_tasks_.push_back(task);
return;
}
std::unique_lock<std::mutex> lock(task_mutex_);
ready_tasks_.push(task);
task_cond_var_.notify_all();
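// Release the Python GIL before blocking, so the worker thread can acquire it
// if graph execution calls back into Python.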
py::gil_scoped_release release;
run_cond_var_.wait(lock);
}
void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<BuildOpTask>();
task->session_ = session;
task->op_run_info_ = op_run_info;
task->graph_info_ = graph_info;
task->input_tensors_ = input_tensors;
task->tensors_mask_ = tensors_mask;
ready_tasks_.push(task);
task_cond_var_.notify_all();
build_op_cond_var_.wait(lock);
}
py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<RunOpTask>();
task->session_ = session;
task->op_run_info_ = op_run_info;
task->graph_info_ = graph_info;
task->input_tensors_ = input_tensors;
ready_tasks_.push(task);
task_cond_var_.notify_all();
run_op_cond_var_.wait(lock);
// Trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(task->outputs_);
if (!utils::isa<PyObjectRef>(output_tensors) ||
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
}
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
return tuple_tensors;
}
void Executor::StopWorker() {
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<ExitTask>();
ready_tasks_.push(task);
task_cond_var_.notify_all();
}
void Executor::OnWorkerExit() {
if (device_name_ == kAscendDevice) {
device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
}
}
} // namespace session
} // namespace mindspore
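The worker loop and the *Async entry points above implement a submit-and-wait handshake: the caller enqueues a task under `task_mutex_`, wakes the worker through `task_cond_var_`, and then blocks on a per-task-type condition variable until the worker has run the task. A minimal, self-contained sketch of the same pattern (standard C++ only; all names are hypothetical, not MindSpore API):

```cpp
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

// Hypothetical mini-executor mirroring Executor's handshake.
class MiniExecutor {
 public:
  MiniExecutor() : worker_([this] { Loop(); }) {}
  ~MiniExecutor() {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      tasks_.push(nullptr);  // a null task plays the role of ExitTask
      task_cv_.notify_all();
    }
    worker_.join();
  }
  // Enqueue a task and block until the worker has executed it,
  // like RunGraphAsync waiting on run_cond_var_.
  void Submit(std::function<void()> task) {
    std::unique_lock<std::mutex> lock(mutex_);
    tasks_.push(std::move(task));
    task_cv_.notify_all();
    // Predicate-free wait, as in the code above; a hardened version
    // would track a completion flag to survive spurious wakeups.
    done_cv_.wait(lock);
  }

 private:
  void Loop() {
    while (true) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        task_cv_.wait(lock, [this] { return !tasks_.empty(); });
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      if (!task) return;  // exit task
      task();
      done_cv_.notify_all();  // wake the caller blocked in Submit()
    }
  }
  std::mutex mutex_;
  std::condition_variable task_cv_;
  std::condition_variable done_cv_;
  std::queue<std::function<void()>> tasks_;
  std::thread worker_;
};

int main() {
  MiniExecutor exec;
  exec.Submit([] { std::cout << "run graph task" << std::endl; });
  return 0;
}
```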
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include <list>
#include <queue>
#include <map>
#include <thread>
#include <mutex>
#include <condition_variable>
#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/contract.h"
namespace mindspore {
namespace session {
enum TaskType { kUnKnown, kExit, kCompileNodes, kCompileGraph, kBuildGraph, kBuildOp, kRunGraph, kRunOp };
class Task {
public:
Task() = default;
virtual ~Task() = default;
SessionPtr session_{nullptr};
TaskType type_{kUnKnown};
virtual void Run() {}
};
class CompileNodesTask : public Task {
public:
CompileNodesTask() { type_ = kCompileNodes; }
~CompileNodesTask() override = default;
void Run() override;
AnfNodePtrList nodes_;
AnfNodePtrList output_nodes_;
GraphId graph_id_{0};
};
class CompileGraphTask : public Task {
public:
CompileGraphTask() { type_ = kCompileGraph; }
~CompileGraphTask() override = default;
void Run() override;
FuncGraphPtr func_graph_{nullptr};
GraphId graph_id_{0};
};
class BuildGraphTask : public Task {
public:
BuildGraphTask() { type_ = kBuildGraph; }
~BuildGraphTask() override = default;
void Run() override;
GraphId graph_id_{0};
};
class RunGraphTask : public Task {
public:
RunGraphTask() { type_ = kRunGraph; }
~RunGraphTask() override = default;
void Run() override;
std::vector<tensor::TensorPtr> input_tensors_;
VectorRef outputs_;
GraphId graph_id_{0};
std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
};
class BuildOpTask : public Task {
public:
BuildOpTask() { type_ = kBuildOp; }
~BuildOpTask() override = default;
void Run() override;
OpRunInfo *op_run_info_{nullptr};
GraphInfo graph_info_;
std::vector<tensor::TensorPtr> input_tensors_;
std::vector<int> tensors_mask_;
};
class RunOpTask : public Task {
public:
RunOpTask() { type_ = kRunOp; }
~RunOpTask() override = default;
void Run() override;
OpRunInfo *op_run_info_{nullptr};
GraphInfo graph_info_;
std::vector<tensor::TensorPtr> input_tensors_;
VectorRef outputs_;
};
class ExitTask : public Task {
public:
ExitTask() { type_ = kExit; }
~ExitTask() override = default;
};
class Executor {
public:
Executor(const std::string &device_name, uint32_t device_id);
~Executor() = default;
void WorkerLoop();
void WorkerJoin();
GraphId CompileGraphAsync(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
GraphId CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph);
void BuildGraphAsync(const SessionPtr &session, GraphId graphId);
void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs);
void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask);
py::tuple RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors);
void OnRunGraphFinished();
protected:
void UpdateOutputTensors(VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node);
std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks();
bool IsAllInputsReady(const std::vector<tensor::TensorPtr> &inputs);
void StopWorker();
void OnWorkerExit();
uint32_t device_id_;
std::string device_name_;
std::mutex task_mutex_;
std::mutex pending_task_mutex_;
std::condition_variable task_cond_var_;
std::condition_variable compile_cond_var_;
std::condition_variable build_cond_var_;
std::condition_variable run_cond_var_;
std::condition_variable build_op_cond_var_;
std::condition_variable run_op_cond_var_;
std::queue<std::shared_ptr<Task>> ready_tasks_;
std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
std::shared_ptr<std::thread> worker_;
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/session/executor_manager.h"
namespace mindspore {
namespace session {
std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device_name, int device_id) {
std::string device_key = device_name + "_" + std::to_string(device_id);
auto iter = executors_.find(device_key);
if (iter != executors_.end()) {
return iter->second;
}
auto executor = std::make_shared<Executor>(device_name, device_id);
executors_[device_key] = executor;
return executor;
}
void ExecutorManager::OnRunGraphFinished() {
for (auto &item : executors_) {
auto &executor = item.second;
if (executor != nullptr) {
executor->OnRunGraphFinished();
}
}
}
void ExecutorManager::JoinExecutorWorkers() {
for (auto &item : executors_) {
auto &executor = item.second;
if (executor != nullptr) {
executor->WorkerJoin();
}
}
}
void ExecutorManager::Clear() {
JoinExecutorWorkers();
executors_.clear();
}
} // namespace session
} // namespace mindspore
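A hedged usage sketch of the manager above (a fragment, not a standalone program; it assumes the MindSpore headers and a built runtime): executors are cached per `"<device_name>_<device_id>"` key, so every session on the same device shares one worker thread, which is what `SessionBasic::InitDevice` relies on further down in this commit.

```cpp
// Sketch only: illustrates the caching contract of ExecutorManager.
auto &manager = mindspore::session::ExecutorManager::Instance();
auto a0 = manager.GetExecutor("Ascend", 0);        // creates and caches "Ascend_0"
auto a0_again = manager.GetExecutor("Ascend", 0);  // returns the cached executor
// a0 == a0_again: one worker thread per device key.
auto g0 = manager.GetExecutor("GPU", 0);           // a distinct executor for "GPU_0"
// At process teardown, ClearResAtexit() calls manager.Clear(), which joins
// every worker before KernelRuntimeManager releases device resources.
```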
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANGER_H_
#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANGER_H_
#include <set>
#include <map>
#include <string>
#include <memory>
#include "backend/session/executor.h"
namespace mindspore {
namespace session {
class Executor;
class ExecutorManager {
public:
static ExecutorManager &Instance() {
static ExecutorManager instance;
return instance;
}
std::shared_ptr<Executor> GetExecutor(const std::string &device_name, int device_id);
void OnRunGraphFinished();
void Clear();
private:
ExecutorManager() = default;
~ExecutorManager() = default;
DISABLE_COPY_AND_ASSIGN(ExecutorManager)
void JoinExecutorWorkers();
std::map<std::string, std::shared_ptr<Executor>> executors_;
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANGER_H_
......@@ -260,16 +260,10 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
InitPSParamAndOptim(kernel_graph, inputs);
#endif
MS_EXCEPTION_IF_NULL(kernel_graph);
{
py::gil_scoped_release gil_release;
// Run graph on GPU
Execute(kernel_graph);
}
Execute(kernel_graph);
#ifdef ENABLE_DEBUGGER
PostLoadTensor(kernel_graph);
#endif
// Get result from GPU
UpdateOutputs(kernel_graph, outputs, inputs);
// Summary
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
......@@ -298,8 +292,8 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in
run_op_graphs_[graph_info] = kernel_graph;
}
py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph);
// Remove NopOp from execution graph
......@@ -307,12 +301,8 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get());
// Execute the computation
LoadInputData(kernel_graph, input_tensors);
{
py::gil_scoped_release gil_release;
Execute(kernel_graph);
}
Execute(kernel_graph);
// Fetch outputs
VectorRef outputs;
if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
......@@ -320,21 +310,12 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
outputs.emplace_back(tensor);
outputs->emplace_back(tensor);
}
} else {
UpdateOutputs(kernel_graph, &outputs, input_tensors);
}
// Trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(outputs);
if (!utils::isa<PyObjectRef>(output_tensors) ||
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
UpdateOutputs(kernel_graph, outputs, input_tensors);
}
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
RunOpClearMemory(kernel_graph.get());
return tuple_tensors;
}
#ifdef ENABLE_DEBUGGER
......
......@@ -18,6 +18,7 @@
#include <vector>
#include <memory>
#include <algorithm>
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/session_factory.h"
......@@ -31,18 +32,15 @@ class GPUSession : public SessionBasic {
GPUSession() = default;
~GPUSession() override = default;
void Init(uint32_t device_id) override {
SessionBasic::Init(device_id);
context_ = std::make_shared<Context>(kGPUDevice, device_id);
}
void Init(uint32_t device_id) override { InitDevice(kGPUDevice, device_id); }
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override;
py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) override;
void RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
private:
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
......
......@@ -318,7 +318,7 @@ void MSInferSession::RegAllOp() {
Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
MS_ASSERT(session_impl_ != nullptr);
try {
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
auto graph_id = session_impl_->CompileGraphAsync(NOT_NULL(funcGraphPtr));
py::gil_scoped_release gil_release;
model_id = graph_id;
return SUCCESS;
......@@ -332,7 +332,7 @@ std::vector<tensor::TensorPtr> MSInferSession::RunGraph(uint32_t graph_id,
const std::vector<tensor::TensorPtr> &inputs) {
try {
VectorRef outputs;
session_impl_->RunGraph(graph_id, inputs, &outputs);
session_impl_->RunGraphAsync(graph_id, inputs, &outputs);
return TransformVectorRefToMultiTensor(outputs);
} catch (std::exception &e) {
......@@ -364,16 +364,16 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
return FAILED;
}
ms_context->set_device_target(device);
if (!context::OpenTsd(ms_context)) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
return FAILED;
}
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
return FAILED;
}
session_impl_->Init(device_id);
if (!context::OpenTsd(ms_context)) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
return FAILED;
}
return SUCCESS;
}
......
......@@ -25,8 +25,11 @@
#include "common/trans.h"
#include "utils/config_manager.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/executor.h"
#include "backend/session/executor_manager.h"
#include "backend/optimizer/common/common_backend_optimization.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/ms_utils.h"
#include "ir/dtype.h"
#include "ir/anf.h"
......@@ -55,8 +58,10 @@ ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
return parameter->default_param();
}
tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph,
const DeviceAddressPtr &address) {
tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
const KernelGraphPtr &graph) {
auto &node = node_output_pair.first;
auto &output_index = node_output_pair.second;
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
......@@ -68,9 +73,9 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index
if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
temp_shape.emplace_back(1);
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_device_address(address);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
tensor->set_dirty(false);
tensor->SetNeedWait(true);
return tensor;
}
......@@ -88,23 +93,25 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index
// in PyNative mode, data is only copied to host when the user wants to print it
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
MS_EXCEPTION_IF_NULL(address);
if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
tensor->set_device_address(address);
tensor->set_dirty(false);
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
MS_LOG(INFO) << "Output sync device to host error!!!";
tensor->set_dirty(false);
if (ms_context->execution_mode() != kPynativeMode && ms_context->device_target() != kGPUDevice) {
tensor->set_need_sync(true);
}
if (ms_context->execution_mode() != kPynativeMode) {
tensor->SetNeedWait(true);
}
tensor->set_dirty(false);
return tensor;
}
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
auto &node = node_output_pair.first;
auto &output_index = node_output_pair.second;
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
MS_EXCEPTION_IF_NULL(tensor_to_node);
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << node_output_pair.second << "]";
// if node is a value node, no need sync addr from device to host
if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>();
......@@ -124,13 +131,16 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
}
}
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
return CreateOutputTensor(node, output_index, graph, address);
auto tensor = CreateCNodeOutputTensor(node_output_pair, graph);
(*tensor_to_node)[tensor] = node_output_pair;
return tensor;
}
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(tensor_to_node);
MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
MS_EXCEPTION_IF_NULL(item_with_index.first);
......@@ -141,7 +151,7 @@ BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph
MS_EXCEPTION_IF_NULL(cnode);
VectorRef ret;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors);
auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node);
ret.push_back(out);
}
return ret;
......@@ -151,7 +161,7 @@ BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph
if (size == 0) {
return VectorRef();
}
return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
return CreateNodeOutputTensor(item_with_index, graph, input_tensors, tensor_to_node);
}
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
......@@ -321,6 +331,12 @@ bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
GraphId SessionBasic::graph_sum_ = 0;
void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id) {
device_id_ = device_id;
context_ = std::make_shared<Context>(device_name, device_id);
executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id);
}
KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
auto it = graphs_.find(graph_id);
if (it == graphs_.end()) {
......@@ -982,11 +998,36 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
const std::vector<tensor::TensorPtr> &input_tensors) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
outputs->emplace_back(CreateTensorForOutput(item, kernel_graph, input_tensors));
outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, &tensor_to_node));
}
for (auto &item : tensor_to_node) {
auto &tensor = item.first;
auto &node = item.second.first;
auto &output_index = item.second.second;
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
tensor->set_device_address(address);
tensor->SetNeedWait(false);
}
}
void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(tensor_to_node);
auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node));
}
}
......@@ -1231,32 +1272,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
return graph;
}
BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
if (utils::isa<VectorRef>(base_ref)) {
auto ref_list = utils::cast<VectorRef>(base_ref);
py::tuple output_tensors(ref_list.size());
for (size_t i = 0; i < ref_list.size(); ++i) {
auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef
if (utils::isa<tensor::TensorPtr>(output)) {
auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
MS_EXCEPTION_IF_NULL(tensor_ptr);
output_tensors[i] = tensor_ptr;
} else if (utils::isa<PyObjectRef>(output)) {
py::object obj = utils::cast<PyObjectRef>(output).object_;
py::tuple tensor_tuple = py::cast<py::tuple>(obj);
output_tensors[i] = tensor_tuple;
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
return output_tensors; // turn tuple to py::object and store in PyObjectRef
} else if (utils::isa<tensor::TensorPtr>(base_ref)) {
return base_ref;
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
KernelGraphPtr SessionBasic::NewKernelGraph() {
auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(graph_sum_);
......@@ -1281,6 +1296,40 @@ AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::ve
return nullptr;
}
GraphId SessionBasic::CompileGraphAsync(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->CompileGraphAsync(shared_from_this(), lst, outputs);
}
GraphId SessionBasic::CompileGraphAsync(NotNull<FuncGraphPtr> func_graph) {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->CompileGraphAsync(shared_from_this(), func_graph);
}
void SessionBasic::BuildGraphAsync(GraphId graph_id) {
MS_EXCEPTION_IF_NULL(executor_);
executor_->BuildGraphAsync(shared_from_this(), graph_id);
}
void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int> &tensors_mask) {
MS_EXCEPTION_IF_NULL(executor_);
executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
}
py::tuple SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors);
}
void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(executor_);
executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
if (!parallel::ps::Util::IsRoleOfWorker()) {
......
......@@ -22,10 +22,10 @@
#include <utility>
#include <memory>
#include <map>
#include "utils/base_ref_extends.h"
#include "backend/session/session_context.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/any.h"
......@@ -49,8 +49,8 @@ using AnyListPtr = std::shared_ptr<AnyList>;
using OpRunInfo = pynative::OpExecInfo;
using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class SessionBasic {
class Executor;
class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
public:
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
#ifdef ENABLE_DEBUGGER
......@@ -60,6 +60,12 @@ class SessionBasic {
virtual void Init(uint32_t device_id) { device_id_ = device_id; }
void InitDevice(const std::string &device_name, uint32_t device_id);
virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node);
virtual ~SessionBasic() { summary_callback_ = nullptr; }
virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
......@@ -69,12 +75,19 @@ class SessionBasic {
virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0;
virtual void BuildOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int> &tensors_mask) {}
virtual void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {}
virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors) {
return py::tuple();
}
virtual void RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {}
GraphId CompileGraphAsync(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
GraphId CompileGraphAsync(NotNull<FuncGraphPtr> func_graph);
void BuildGraphAsync(GraphId graphId);
void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int> &tensors_mask);
py::tuple RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors);
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
......@@ -116,9 +129,11 @@ class SessionBasic {
void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
protected:
virtual void SetSummaryNodes(KernelGraph *graph);
// Get graph by graph id; return nullptr if it does not exist
KernelGraphPtr GetGraph(GraphId graph_id) const;
virtual void SetSummaryNodes(KernelGraph *graph);
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const;
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
......@@ -132,8 +147,6 @@ class SessionBasic {
std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int> &tensors_mask);
// trans BaseRef list to py::tuple
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
// create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph();
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph);
......@@ -152,6 +165,7 @@ class SessionBasic {
CallBackFunc summary_callback_;
static GraphId graph_sum_;
uint32_t device_id_;
std::shared_ptr<Executor> executor_;
#ifdef ENABLE_DEBUGGER
std::shared_ptr<Debugger> debugger_;
#endif
......
......@@ -37,6 +37,7 @@
#include "frontend/parallel/context.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/session/executor_manager.h"
#include "debug/trace.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h"
......@@ -1023,7 +1024,6 @@ void ClearResAtexit() {
MS_LOG(DEBUG) << "Pipeline clear all resource";
pynative::ClearPyNativeSession();
session::ClearPythonParasMap();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (mindspore::parallel::ps::Util::IsParamServerMode()) {
if (parallel::ps::Util::IsRoleOfWorker()) {
......@@ -1047,6 +1047,8 @@ void ClearResAtexit() {
#else
ConfigManager::GetInstance().ResetIterNum();
#endif
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
ReleaseGeTsd();
parse::python_adapter::ResetPythonScope();
}
......
......@@ -574,9 +574,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
// get graph info for checking it whether existing in the cache
std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask);
session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, &input_tensors);
py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors);
ms_context->set_enable_pynative_infer(false);
*status = PYNATIVE_SUCCESS;
MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
......
......@@ -94,7 +94,18 @@ std::string GetRankId() {
AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); }
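// With graphs now executed on the executor worker thread, the rt context created in
// InitDevice must be rebound on whichever thread enters the runtime, hence the
// SetContext() calls added at each entry point below.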
void AscendKernelRuntime::SetContext() {
if (rt_context_ == nullptr) {
return;
}
auto ret = rtCtxSetCurrent(rt_context_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
}
}
void AscendKernelRuntime::ClearGraphModelMap() {
SetContext();
for (auto &iter : graph_data_dumper_) {
MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first;
auto &data_dumper = iter.second;
......@@ -118,6 +129,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &,
const std::unordered_set<ValueNodePtr> &,
const std::vector<CNodePtr> &) {
SetContext();
MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper";
if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
MS_LOG(DEBUG) << "Unload dump info " << graph_id;
......@@ -156,6 +168,10 @@ bool AscendKernelRuntime::NeedDestroyHccl() {
void AscendKernelRuntime::ReleaseDeviceRes() {
MS_LOG(INFO) << "Ascend finalize start";
if (!initialized_) {
return;
}
SetContext();
// release ge runtime
ClearGraphModelMap();
......@@ -438,6 +454,7 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
}
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
SetContext();
if (graph == nullptr) {
MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
}
......@@ -494,6 +511,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
}
bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
SetContext();
if (graph == nullptr) {
MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. ";
}
......@@ -586,6 +604,7 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) {
}
bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
SetContext();
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id();
......@@ -613,6 +632,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
}
bool AscendKernelRuntime::SyncStream() {
SetContext();
if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { // o for switch stream
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
......@@ -649,12 +669,7 @@ bool AscendKernelRuntime::InitDevice() {
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
}
ret = rtCtxSetCurrent(rt_context_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
}
SetContext();
ret = rtStreamCreate(&stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
......@@ -664,14 +679,9 @@ bool AscendKernelRuntime::InitDevice() {
}
bool AscendKernelRuntime::ResetDevice() {
auto ret = rtCtxSetCurrent(rt_context_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rtCtxSetCurrent failed";
return false;
}
SetContext();
if (stream_ != nullptr) {
ret = rtStreamDestroy(stream_);
auto ret = rtStreamDestroy(stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
}
......@@ -679,7 +689,7 @@ bool AscendKernelRuntime::ResetDevice() {
}
if (rt_context_ != nullptr) {
ret = rtCtxDestroy(rt_context_);
auto ret = rtCtxDestroy(rt_context_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]";
}
......
......@@ -60,6 +60,7 @@ class AscendKernelRuntime : public KernelRuntime {
bool HcclInit();
bool NeedDestroyHccl();
bool DestroyHccl();
void SetContext();
void ClearGraphModelMap();
void ReleaseDeviceRes() override;
......
......@@ -19,6 +19,7 @@
#include <memory>
#include <numeric>
#include <utility>
#include <functional>
#include "backend/kernel_compiler/kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_context.h"
......@@ -137,10 +138,8 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
}
tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node,
size_t index,
std::vector<tensor::TensorPtr> *need_sync_outputs) {
size_t index) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(need_sync_outputs);
size_t output_size = AnfAlgo::GetOutputTensorNum(node);
if (index >= output_size) {
MS_LOG(EXCEPTION) << "Invalid input index " << index;
......@@ -163,16 +162,15 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
}
if (bound_addresses_.find(address) != bound_addresses_.end()) {
tensor->set_device_address(address);
need_sync_outputs->emplace_back(tensor);
tensor->set_need_sync(true);
} else {
if (infer_type_id != device_type_id) {
size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
ShapeVector data_shape = tensor->shape();
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
address->ptr_ = resource_manager_.MemMalloc(tensor_size);
need_sync_outputs->emplace_back(tensor);
tensor->set_device_address(address);
need_sync_outputs->emplace_back(tensor);
tensor->set_need_sync(true);
} else {
tensor->set_device_address(nullptr);
address->ptr_ = tensor->data_c();
......@@ -185,8 +183,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
}
BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph,
const session::KernelWithIndex &kernel_with_index,
std::vector<tensor::TensorPtr> *need_sync_outputs) {
const session::KernelWithIndex &kernel_with_index) {
auto &input_node = kernel_with_index.first;
auto index = kernel_with_index.second;
MS_EXCEPTION_IF_NULL(input_node);
......@@ -197,12 +194,12 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_grap
VectorRef ret;
for (size_t i = 1; i < node->inputs().size(); i++) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0);
auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs);
auto out = CreatTensorForOutput(kernel_graph, item_with_index);
ret.push_back(out);
}
return ret;
}
return CreatTensorForOutput(kernel_graph, node, index, need_sync_outputs);
return CreatTensorForOutput(kernel_graph, node, index);
} else if (input_node->isa<Parameter>()) {
auto iter = input_param_tensor_map_.find(input_node);
if (iter != input_param_tensor_map_.end()) {
......@@ -216,7 +213,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_grap
return BaseRef();
}
void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs, std::vector<tensor::TensorPtr> *need_sync_outputs) {
VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
// bind input ptr
......@@ -262,7 +259,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const
auto output_nodes = kernel_graph->outputs();
for (const auto &item : output_nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs);
auto out = CreatTensorForOutput(kernel_graph, item_with_index);
outputs->push_back(std::move(out));
}
}
......
......@@ -39,7 +39,7 @@ class CPUKernelRuntime : public KernelRuntime {
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
void AssignKernelAddress(session::KernelGraph *kernel_graph);
void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs, std::vector<tensor::TensorPtr> *need_sync_outputs);
VectorRef *outputs);
void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
......@@ -49,11 +49,9 @@ class CPUKernelRuntime : public KernelRuntime {
TypeId type_id) override;
private:
tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
std::vector<tensor::TensorPtr> *need_sync_outputs);
tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index);
BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index,
std::vector<tensor::TensorPtr> *need_sync_outputs);
BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index);
void AssignValueNodeAddress(session::KernelGraph *kernel_graph);
void AssignInputNodeAddress(const session::KernelGraph *kernel_graph);
void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph);
......
......@@ -49,8 +49,13 @@ void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntim
}
}
std::string KernelRuntimeManager::GetDeviceKey(const std::string &device_name, uint32_t device_id) {
std::string device_key = device_name + "_" + std::to_string(device_id);
return device_key;
}
KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) {
std::string runtime_key = device_name + "_" + std::to_string(device_id);
auto runtime_key = GetDeviceKey(device_name, device_id);
auto runtime_iter = runtime_map_.find(runtime_key);
if (runtime_iter != runtime_map_.end()) {
return runtime_iter->second.get();
......@@ -72,8 +77,8 @@ KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &d
}
KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) {
std::string runtime_key = GetDeviceKey(device_name, device_id);
std::lock_guard<std::mutex> guard(lock_);
std::string runtime_key = device_name + "_" + std::to_string(device_id);
auto runtime_iter = runtime_map_.find(runtime_key);
if (runtime_iter != runtime_map_.end()) {
return runtime_iter->second.get();
......@@ -92,5 +97,20 @@ KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_
return kernel_runtime.get();
}
void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) {
std::string runtime_key = GetDeviceKey(device_name, device_id);
std::lock_guard<std::mutex> guard(lock_);
auto runtime_iter = runtime_map_.find(runtime_key);
if (runtime_iter == runtime_map_.end()) {
return;
}
auto runtime = runtime_iter->second.get();
if (runtime == nullptr) {
return;
}
runtime->ReleaseDeviceRes();
runtime_map_.erase(runtime_iter);
}
} // namespace device
} // namespace mindspore
......@@ -39,6 +39,7 @@ class KernelRuntimeManager {
void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator);
KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id);
KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id);
void ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id);
void ClearRuntimeResource();
void ClearGraphResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
......@@ -48,6 +49,7 @@ class KernelRuntimeManager {
KernelRuntimeManager() = default;
~KernelRuntimeManager() = default;
DISABLE_COPY_AND_ASSIGN(KernelRuntimeManager);
std::string GetDeviceKey(const std::string &device_name, uint32_t device_id);
std::map<std::string, std::shared_ptr<KernelRuntime> > runtime_map_;
std::map<std::string, KernelRuntimeCreator> runtime_creators_;
std::mutex lock_;
......
......@@ -54,9 +54,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
GraphId graph_id = kInvalidGraphId;
if (target != target_device_ && !target.empty()) {
CreateOtherSession(target);
graph_id = other_sess_->CompileGraph(lst, outputs);
graph_id = other_sess_->CompileGraphAsync(lst, outputs);
} else {
graph_id = target_sess_->CompileGraph(lst, outputs);
graph_id = target_sess_->CompileGraphAsync(lst, outputs);
}
if (MsContext::GetInstance()->precompile_only()) {
......@@ -64,9 +64,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
return result;
}
if (target != target_device_ && !target.empty()) {
other_sess_->BuildGraph(graph_id);
other_sess_->BuildGraphAsync(graph_id);
} else if (!is_multi_graph_sink_) {
target_sess_->BuildGraph(graph_id);
target_sess_->BuildGraphAsync(graph_id);
}
result.run = std::make_shared<RunFunc>(
[graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
......@@ -137,9 +137,9 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
VectorRef outputs;
// call ms rungraph (graphId, input ,output)
if (target != target_device_ && !target.empty()) {
other_sess_->RunGraph(g, inputs, &outputs);
other_sess_->RunGraphAsync(g, inputs, &outputs);
} else {
target_sess_->RunGraph(g, inputs, &outputs);
target_sess_->RunGraphAsync(g, inputs, &outputs);
}
MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
......@@ -150,7 +150,7 @@ void MsBackend::Link(GraphId graph_id) {
if (graph_id == kInvalidGraphId) {
graph_id = target_sess_->GetFinalRunGraph();
}
target_sess_->BuildGraph(graph_id);
target_sess_->BuildGraphAsync(graph_id);
}
Backend::Backend(const std::string &name) : name_(name) {
......@@ -186,7 +186,7 @@ void MsBackend::CreateOtherSession(const std::string &target) {
other_device_ = target;
}
GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); }
GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraphAsync(fg); }
VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
......
......@@ -424,6 +424,8 @@ Tensor::Tensor(const Tensor &tensor)
data_(tensor.data_),
dirty_(tensor.dirty_),
id_(tensor.id_),
event_(tensor.event_),
need_sync_(tensor.need_sync_),
device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {}
......@@ -433,6 +435,8 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
dirty_(tensor.dirty_),
id_(tensor.id_),
event_(tensor.event_),
need_sync_(tensor.need_sync_),
device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {}
......@@ -483,6 +487,8 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
device_sync_ = tensor.device_sync_;
data_ = tensor.data_;
id_ = tensor.id_;
event_ = tensor.event_;
need_sync_ = tensor.need_sync_;
padding_type_ = tensor.padding_type_;
}
return *this;
......@@ -547,6 +553,7 @@ std::string Tensor::ToStringRepr() const {
}
void Tensor::data_sync() const {
const_cast<Tensor *>(this)->Wait();
if (device_sync_ != nullptr) {
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
......
......@@ -21,6 +21,8 @@
#include <string>
#include <vector>
#include <numeric>
#include <mutex>
#include <condition_variable>
#include "ir/device_sync.h"
#include "ir/meta_tensor.h"
......@@ -73,6 +75,30 @@ class TensorData {
using TensorDataPtr = std::shared_ptr<TensorData>;
struct WaitEvent {
bool need_wait_{false};
std::mutex mutex_;
std::condition_variable cond_var_;
void Wait() {
std::unique_lock<std::mutex> lock(mutex_);
if (!need_wait_) {
return;
}
cond_var_.wait(lock, [this] { return !need_wait_; });
}
void set_need_wait(bool need_wait) {
std::unique_lock<std::mutex> lock(mutex_);
need_wait_ = need_wait;
if (!need_wait_) {
cond_var_.notify_all();
}
}
bool need_wait() const { return need_wait_; }
};
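A minimal sketch of the WaitEvent handshake, assuming the struct defined just above is in scope (standard C++ otherwise): the dispatcher marks the data as pending, a worker produces it and clears the flag, and a reader's `Wait()` blocks in between. This is the same protocol `Tensor` exposes through `SetNeedWait()`/`Wait()` below.

```cpp
#include <chrono>
#include <iostream>
#include <thread>
// Assumes the WaitEvent struct defined above is visible here.

int main() {
  WaitEvent event;
  event.set_need_wait(true);  // dispatcher: result not produced yet
  std::thread worker([&event] {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));  // simulate RunGraph
    event.set_need_wait(false);  // result ready; notify_all wakes every waiter
  });
  event.Wait();  // reader blocks until the worker clears the flag
  std::cout << "tensor data is ready" << std::endl;
  worker.join();
  return 0;
}
```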
// Tensor entity class
class Tensor : public MetaTensor {
public:
......@@ -244,11 +270,40 @@ class Tensor : public MetaTensor {
std::string id() const { return id_; }
void SetNeedWait(bool need_wait) {
if (event_ != nullptr) {
event_->set_need_wait(need_wait);
} else if (need_wait) {
event_ = std::make_shared<WaitEvent>();
event_->set_need_wait(need_wait);
}
}
bool NeedWait() const {
if (event_ != nullptr) {
return event_->need_wait();
}
return false;
}
void Wait() {
if (event_ != nullptr) {
event_->Wait();
}
event_ = nullptr;
}
void set_need_sync(bool need_sync) { need_sync_ = need_sync; }
bool need_sync() const { return need_sync_; }
private:
bool init_flag_{false};
TensorDataPtr data_{nullptr};
bool dirty_{true};
std::string id_{""};
std::shared_ptr<WaitEvent> event_{nullptr};
bool need_sync_{false};
DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_;
};
......
......@@ -84,6 +84,8 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/backend/session/ascend_control_parser.cc"
"../../../mindspore/ccsrc/backend/session/kernel_graph.cc"
"../../../mindspore/ccsrc/backend/session/session_basic.cc"
"../../../mindspore/ccsrc/backend/session/executor.cc"
"../../../mindspore/ccsrc/backend/session/executor_manager.cc"
"../../../mindspore/ccsrc/backend/session/session_factory.cc"
"../../../mindspore/ccsrc/backend/session/kernel_build_client.cc"
"../../../mindspore/ccsrc/vm/*.cc"
......