Commit ebff566a authored by kswang

add group operation for executor

Parent bc4c5afc
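
This commit routes communication-group creation and destruction through the session executor's task queue instead of having GroupManager call CommManager directly. As a rough illustration, a caller could drive the new interface the way GroupManager::CreateGroup does later in this diff; the function name, group name, and ranks below are illustrative, not part of the commit:

#include <string>
#include <vector>
#include "backend/session/executor_manager.h"
#include "utils/ms_context.h"

bool CreateExampleGroup() {
  // Look up the executor for the current device target, as the diff does.
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
  MS_EXCEPTION_IF_NULL(executor);
  // Blocks on the executor's sync_cond_var_ until the worker thread has
  // run CreateCommGroupTask (which calls CommManager::CreateGroupSync).
  std::vector<uint32_t> ranks = {0, 1, 2, 3};                // illustrative
  return executor->CreateCommGroup("example_group", ranks);  // illustrative
}
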
@@ -16,6 +16,7 @@
#include "backend/session/executor.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/session/executor_manager.h"
#include "utils/comm_manager.h"
namespace mindspore {
namespace session {
@@ -45,32 +46,6 @@ void UpdateOutputTensors(VectorRef *outputs,
}
}
}
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
if (utils::isa<VectorRef>(base_ref)) {
auto ref_list = utils::cast<VectorRef>(base_ref);
py::tuple output_tensors(ref_list.size());
for (size_t i = 0; i < ref_list.size(); ++i) {
auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef
if (utils::isa<tensor::TensorPtr>(output)) {
auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
MS_EXCEPTION_IF_NULL(tensor_ptr);
output_tensors[i] = tensor_ptr;
} else if (utils::isa<PyObjectRef>(output)) {
py::object obj = utils::cast<PyObjectRef>(output).object_;
py::tuple tensor_tuple = py::cast<py::tuple>(obj);
output_tensors[i] = tensor_tuple;
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
return output_tensors; // turn tuple to py::object and store in PyObjectRef
} else if (utils::isa<tensor::TensorPtr>(base_ref)) {
return base_ref;
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
} // namespace
void CompileNodesTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
@@ -104,6 +79,10 @@ void RunOpTask::Run() {
session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_);
}
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
Executor::Executor(const std::string &device_name, uint32_t device_id) {
device_name_ = device_name;
device_id_ = device_id;
@@ -141,22 +120,8 @@ void Executor::WorkerLoop() {
} catch (const std::exception &e) {
exception_ptr_ = std::current_exception();
}
auto task_type = task->type_;
task = nullptr;
if (task_type == kCompileNodes) {
compile_cond_var_.notify_all();
} else if (task_type == kCompileGraph) {
compile_cond_var_.notify_all();
} else if (task_type == kBuildGraph) {
build_cond_var_.notify_all();
} else if (task_type == kRunGraph) {
run_cond_var_.notify_all();
} else if (task_type == kBuildOp) {
build_op_cond_var_.notify_all();
} else if (task_type == kRunOp) {
run_op_cond_var_.notify_all();
}
sync_cond_var_.notify_all();
}
}
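
The hunk above collapses the five per-task-type condition variables into a single sync_cond_var_: the worker notifies once after finishing any task, and every blocking caller waits on the same variable. A self-contained sketch of that pattern using only standard-library types (simplified; the real code also records exceptions, and like the diff it waits without a predicate):

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

std::mutex task_mutex_;
std::condition_variable task_cond_var_;  // wakes the worker
std::condition_variable sync_cond_var_;  // wakes whichever caller is blocked
std::queue<std::function<void()>> ready_tasks_;
bool stop_ = false;

void WorkerLoop() {
  while (true) {
    std::function<void()> task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [] { return stop_ || !ready_tasks_.empty(); });
      if (stop_ && ready_tasks_.empty()) return;
      task = ready_tasks_.front();
      ready_tasks_.pop();
    }
    task();                       // run outside the lock
    sync_cond_var_.notify_all();  // one notification covers every task type
  }
}

int main() {
  std::thread worker(WorkerLoop);
  {
    std::unique_lock<std::mutex> lock(task_mutex_);
    ready_tasks_.push([] { std::cout << "task ran\n"; });
    task_cond_var_.notify_all();
    sync_cond_var_.wait(lock);  // as in the diff: block until the worker notifies
  }
  {
    std::unique_lock<std::mutex> lock(task_mutex_);
    stop_ = true;
    task_cond_var_.notify_all();
  }
  worker.join();
}
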
@@ -206,7 +171,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL
task->output_nodes_ = outputs;
ready_tasks_.push(task);
task_cond_var_.notify_all();
compile_cond_var_.wait(lock);
sync_cond_var_.wait(lock);
CheckException();
return task->graph_id_;
}
@@ -219,7 +184,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph
task->func_graph_ = func_graph;
ready_tasks_.push(task);
task_cond_var_.notify_all();
compile_cond_var_.wait(lock);
sync_cond_var_.wait(lock);
CheckException();
return task->graph_id_;
}
@@ -232,7 +197,7 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) {
task->graph_id_ = graphId;
ready_tasks_.push(task);
task_cond_var_.notify_all();
build_cond_var_.wait(lock);
sync_cond_var_.wait(lock);
CheckException();
}
@@ -258,7 +223,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
ready_tasks_.push(task);
task_cond_var_.notify_all();
py::gil_scoped_release release;
run_cond_var_.wait(lock);
sync_cond_var_.wait(lock);
CheckException();
}
@@ -274,12 +239,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c
task->tensors_mask_ = tensors_mask;
ready_tasks_.push(task);
task_cond_var_.notify_all();
build_op_cond_var_.wait(lock);
sync_cond_var_.wait(lock);
CheckException();
}
py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
void Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
CheckException();
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<RunOpTask>();
@@ -289,18 +254,30 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info
task->input_tensors_ = input_tensors;
ready_tasks_.push(task);
task_cond_var_.notify_all();
run_op_cond_var_.wait(lock);
sync_cond_var_.wait(lock);
CheckException();
*outputs = task->outputs_;
}
// Trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(task->outputs_);
if (!utils::isa<PyObjectRef>(output_tensors) ||
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
}
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
return tuple_tensors;
bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<CreateCommGroupTask>();
task->group_name_ = group_name;
task->ranks_ = ranks;
ready_tasks_.push(task);
task_cond_var_.notify_all();
sync_cond_var_.wait(lock);
return task->result_;
}
bool Executor::DestroyCommGroup(const std::string &group_name) {
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<DestroyCommGroupTask>();
task->group_name_ = group_name;
ready_tasks_.push(task);
task_cond_var_.notify_all();
sync_cond_var_.wait(lock);
return task->result_;
}
void Executor::StopWorker() {
......
@@ -32,10 +32,22 @@
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/contract.h"
#include "utils/comm_manager.h"
namespace mindspore {
namespace session {
enum TaskType { kUnKnown, kExit, kCompileNodes, kCompileGraph, kBuildGraph, kBuildOp, kRunGraph, kRunOp };
enum TaskType {
kUnKnown,
kExit,
kCompileNodes,
kCompileGraph,
kBuildGraph,
kBuildOp,
kRunGraph,
kRunOp,
kCreateCommGroup,
kDestroyCommGroup
};
class Task {
public:
@@ -106,6 +118,25 @@ class RunOpTask : public Task {
VectorRef outputs_;
};
class CreateCommGroupTask : public Task {
public:
CreateCommGroupTask() { type_ = kCreateCommGroup; }
~CreateCommGroupTask() override = default;
void Run() override;
std::string group_name_;
std::vector<uint32_t> ranks_;
bool result_;
};
class DestroyCommGroupTask : public Task {
public:
DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
~DestroyCommGroupTask() override = default;
void Run() override;
std::string group_name_;
bool result_;
};
class ExitTask : public Task {
public:
ExitTask() { type_ = kExit; }
@@ -125,9 +156,11 @@ class Executor {
VectorRef *outputs);
void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask);
py::tuple RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors);
void RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
void OnRunGraphFinished();
bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
bool DestroyCommGroup(const std::string &group_name);
private:
void UpdateOutputTensors(VectorRef *outputs,
@@ -143,11 +176,7 @@
std::mutex task_mutex_;
std::mutex pending_task_mutex_;
std::condition_variable task_cond_var_;
std::condition_variable compile_cond_var_;
std::condition_variable build_cond_var_;
std::condition_variable run_cond_var_;
std::condition_variable build_op_cond_var_;
std::condition_variable run_op_cond_var_;
std::condition_variable sync_cond_var_;
std::queue<std::shared_ptr<Task>> ready_tasks_;
std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
std::shared_ptr<std::thread> worker_;
......
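
The executor.h hunk above shows the full recipe this commit follows for a new blocking operation: an enum value, a Task subclass that carries the inputs and a result field, and an Executor method that queues the task and waits. As a hypothetical illustration of extending the hierarchy the same way (none of these names exist in the codebase):

class QueryCommGroupSizeTask : public Task {  // hypothetical example
 public:
  QueryCommGroupSizeTask() { type_ = kQueryCommGroupSize; }  // hypothetical enum value
  ~QueryCommGroupSizeTask() override = default;
  void Run() override;      // would run on the worker thread via WorkerLoop
  std::string group_name_;  // input, set by the caller before queuing
  uint32_t result_{0};      // output, read after sync_cond_var_ wakes the caller
};
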
@@ -1344,10 +1344,10 @@ void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_i
executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
}
py::tuple SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
void SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors);
executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
}
void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
......
@@ -90,7 +90,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int> &tensors_mask);
py::tuple RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors);
void RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs);
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
......
@@ -15,12 +15,12 @@
*/
#include "frontend/parallel/group_manager.h"
#include <algorithm>
#include <vector>
#include "frontend/parallel/device_manager.h"
#include "backend/session/executor_manager.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace parallel {
@@ -96,8 +96,14 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
vector<uint32_t> ranks;
(void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
[](const Device dev) { return (uint32_t)dev.rank(); });
// Create group through the CommManager interface
bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks);
// Create group through the executor
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
MS_EXCEPTION_IF_NULL(executor);
bool ret = executor->CreateCommGroup(group_name, ranks);
if (!ret) {
MS_LOG(ERROR) << "Create group failed, group name is " << group_name;
return Status::FAILED;
@@ -108,6 +114,20 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
}
}
Status GroupManager::DestroyGroup(const std::string &group_name) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
MS_EXCEPTION_IF_NULL(executor);
bool ret = executor->DestroyCommGroup(group_name);
if (!ret) {
return Status::FAILED;
}
return Status::SUCCESS;
}
Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
std::string name = (*group).name();
auto it = groups_.find(name);
@@ -116,18 +136,14 @@ Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
return Status::FAILED;
}
(void)groups_.erase(it);
bool ret = CommManager::GetInstance().DestroyGroup(name);
if (!ret) {
return Status::FAILED;
}
return Status::SUCCESS;
return DestroyGroup(name);
}
Status GroupManager::DestroyAllGroups() {
for (auto &it : groups_) {
std::string name = it.first;
bool ret = CommManager::GetInstance().DestroyGroup(name);
if (!ret) {
auto ret = DestroyGroup(name);
if (ret != Status::SUCCESS) {
return Status::FAILED;
}
}
......
@@ -65,6 +65,7 @@ class GroupManager {
void Clear();
private:
Status DestroyGroup(const std::string &group_name);
// the key is group name (name_)
std::map<std::string, Group> groups_;
std::string world_group_;
......
@@ -19,18 +19,22 @@
#include <typeinfo>
#include <map>
#include <set>
#include <memory>
#include <unordered_set>
#include <algorithm>
#include "debug/trace.h"
#include "pybind_api/ir/tensor_py.h"
#include "ir/param_info.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "utils/context/context_extends.h"
#include "utils/config_manager.h"
#include "utils/convert_utils_py.h"
#include "utils/base_ref_extends.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
#include "frontend/operator/composite/do_signature.h"
@@ -554,6 +558,32 @@ void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tens
*input_tensors = new_input_tensors;
}
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
if (utils::isa<VectorRef>(base_ref)) {
auto ref_list = utils::cast<VectorRef>(base_ref);
py::tuple output_tensors(ref_list.size());
for (size_t i = 0; i < ref_list.size(); ++i) {
auto output = TransformBaseRefListToTuple(ref_list[i]);
if (utils::isa<tensor::TensorPtr>(output)) {
auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
MS_EXCEPTION_IF_NULL(tensor_ptr);
output_tensors[i] = tensor_ptr;
} else if (utils::isa<PyObjectRef>(output)) {
py::object obj = utils::cast<PyObjectRef>(output).object_;
py::tuple tensor_tuple = py::cast<py::tuple>(obj);
output_tensors[i] = tensor_tuple;
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
return std::make_shared<PyObjectRef>(output_tensors);
} else if (utils::isa<tensor::TensorPtr>(base_ref)) {
return base_ref;
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
MS_EXCEPTION_IF_NULL(op_exec_info);
MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
@@ -577,7 +607,19 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, &input_tensors);
py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors);
VectorRef outputs;
session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors, &outputs);
// Trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(outputs);
if (!utils::isa<PyObjectRef>(output_tensors) ||
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
}
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
py::tuple result = py::cast<py::tuple>(tuple_obj);
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
*status = PYNATIVE_SUCCESS;
MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
......
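
TransformBaseRefListToTuple, moved into pynative_execute.cc by this commit, recurses through nested VectorRefs, wrapping each level in a py::tuple and passing tensors through unchanged. A self-contained sketch of the same recursion using standard-library stand-ins (ints play the tensor role, strings the tuple role; nothing below is a MindSpore or pybind11 API):

#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Illustrative stand-ins: a Node is either a leaf "tensor" (int) or a
// nested list, mirroring BaseRef being a TensorPtr or a VectorRef.
struct Node;
using List = std::vector<Node>;  // C++17: vector of incomplete type is allowed
struct Node { std::variant<int, List> value; };

// Mirrors TransformBaseRefListToTuple: recurse into lists, pass leaves
// through, and render each list level as one "tuple".
std::string TransformToTuple(const Node &node) {
  if (std::holds_alternative<List>(node.value)) {
    const auto &list = std::get<List>(node.value);
    std::string out = "(";
    for (size_t i = 0; i < list.size(); ++i) {
      out += TransformToTuple(list[i]);
      if (i + 1 < list.size()) out += ", ";
    }
    return out + ")";
  }
  return std::to_string(std::get<int>(node.value));  // leaf "tensor"
}

int main() {
  Node nested{List{Node{1}, Node{List{Node{2}, Node{3}}}}};
  std::cout << TransformToTuple(nested) << "\n";  // prints (1, (2, 3))
}
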