Commit ebff566a — Author: kswang

add group operation for executor

Parent commit: bc4c5afc
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "backend/session/executor.h" #include "backend/session/executor.h"
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
#include "backend/session/executor_manager.h" #include "backend/session/executor_manager.h"
#include "utils/comm_manager.h"
namespace mindspore { namespace mindspore {
namespace session { namespace session {
...@@ -45,32 +46,6 @@ void UpdateOutputTensors(VectorRef *outputs, ...@@ -45,32 +46,6 @@ void UpdateOutputTensors(VectorRef *outputs,
} }
} }
} }
// Recursively converts a BaseRef produced by graph execution into a python
// value: a VectorRef becomes a py::tuple whose elements are tensors or nested
// tuples; a bare tensor is passed through unchanged.  Any other BaseRef kind
// raises an exception.
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<VectorRef>(base_ref)) {
    auto ref_list = utils::cast<VectorRef>(base_ref);
    py::tuple output_tensors(ref_list.size());
    for (size_t i = 0; i < ref_list.size(); ++i) {
      // Depth-first: each element is either a tensor or a nested list that has
      // already been turned into a py::tuple (carried inside a PyObjectRef).
      auto output = TransformBaseRefListToTuple(ref_list[i]);
      if (utils::isa<tensor::TensorPtr>(output)) {
        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        output_tensors[i] = tensor_ptr;
      } else if (utils::isa<PyObjectRef>(output)) {
        py::object obj = utils::cast<PyObjectRef>(output).object_;
        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
        output_tensors[i] = tensor_tuple;
      } else {
        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
      }
    }
    // NOTE(review): relies on BaseRef's converting constructor to wrap the
    // py::tuple — presumably it is stored as a PyObjectRef; confirm, since
    // callers detect the tuple case via utils::isa<PyObjectRef>.
    return output_tensors;
  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  } else {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
}
} // namespace } // namespace
void CompileNodesTask::Run() { void CompileNodesTask::Run() {
MS_EXCEPTION_IF_NULL(session_); MS_EXCEPTION_IF_NULL(session_);
...@@ -104,6 +79,10 @@ void RunOpTask::Run() { ...@@ -104,6 +79,10 @@ void RunOpTask::Run() {
session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_); session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_);
} }
// Executed on the executor worker thread: synchronously creates the
// communication group named group_name_ over ranks_ via CommManager and
// records success/failure in result_.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
// Executed on the executor worker thread: destroys the communication group
// named group_name_ via CommManager and records success/failure in result_.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
Executor::Executor(const std::string &device_name, uint32_t device_id) { Executor::Executor(const std::string &device_name, uint32_t device_id) {
device_name_ = device_name; device_name_ = device_name;
device_id_ = device_id; device_id_ = device_id;
...@@ -141,22 +120,8 @@ void Executor::WorkerLoop() { ...@@ -141,22 +120,8 @@ void Executor::WorkerLoop() {
} catch (const std::exception &e) { } catch (const std::exception &e) {
exception_ptr_ = std::current_exception(); exception_ptr_ = std::current_exception();
} }
auto task_type = task->type_;
task = nullptr; task = nullptr;
if (task_type == kCompileNodes) { sync_cond_var_.notify_all();
compile_cond_var_.notify_all();
} else if (task_type == kCompileGraph) {
compile_cond_var_.notify_all();
} else if (task_type == kBuildGraph) {
build_cond_var_.notify_all();
} else if (task_type == kRunGraph) {
run_cond_var_.notify_all();
} else if (task_type == kBuildOp) {
build_op_cond_var_.notify_all();
} else if (task_type == kRunOp) {
run_op_cond_var_.notify_all();
}
} }
} }
...@@ -206,7 +171,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL ...@@ -206,7 +171,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL
task->output_nodes_ = outputs; task->output_nodes_ = outputs;
ready_tasks_.push(task); ready_tasks_.push(task);
task_cond_var_.notify_all(); task_cond_var_.notify_all();
compile_cond_var_.wait(lock); sync_cond_var_.wait(lock);
CheckException(); CheckException();
return task->graph_id_; return task->graph_id_;
} }
...@@ -219,7 +184,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph ...@@ -219,7 +184,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph
task->func_graph_ = func_graph; task->func_graph_ = func_graph;
ready_tasks_.push(task); ready_tasks_.push(task);
task_cond_var_.notify_all(); task_cond_var_.notify_all();
compile_cond_var_.wait(lock); sync_cond_var_.wait(lock);
CheckException(); CheckException();
return task->graph_id_; return task->graph_id_;
} }
...@@ -232,7 +197,7 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) { ...@@ -232,7 +197,7 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) {
task->graph_id_ = graphId; task->graph_id_ = graphId;
ready_tasks_.push(task); ready_tasks_.push(task);
task_cond_var_.notify_all(); task_cond_var_.notify_all();
build_cond_var_.wait(lock); sync_cond_var_.wait(lock);
CheckException(); CheckException();
} }
...@@ -258,7 +223,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, ...@@ -258,7 +223,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
ready_tasks_.push(task); ready_tasks_.push(task);
task_cond_var_.notify_all(); task_cond_var_.notify_all();
py::gil_scoped_release release; py::gil_scoped_release release;
run_cond_var_.wait(lock); sync_cond_var_.wait(lock);
CheckException(); CheckException();
} }
...@@ -274,12 +239,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c ...@@ -274,12 +239,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c
task->tensors_mask_ = tensors_mask; task->tensors_mask_ = tensors_mask;
ready_tasks_.push(task); ready_tasks_.push(task);
task_cond_var_.notify_all(); task_cond_var_.notify_all();
build_op_cond_var_.wait(lock); sync_cond_var_.wait(lock);
CheckException(); CheckException();
} }
py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, void Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) { const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
CheckException(); CheckException();
std::unique_lock<std::mutex> lock(task_mutex_); std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<RunOpTask>(); auto task = std::make_shared<RunOpTask>();
...@@ -289,18 +254,30 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info ...@@ -289,18 +254,30 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info
task->input_tensors_ = input_tensors; task->input_tensors_ = input_tensors;
ready_tasks_.push(task); ready_tasks_.push(task);
task_cond_var_.notify_all(); task_cond_var_.notify_all();
run_op_cond_var_.wait(lock); sync_cond_var_.wait(lock);
CheckException(); CheckException();
*outputs = task->outputs_;
}
// Trans output to tuple bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
auto output_tensors = TransformBaseRefListToTuple(task->outputs_); std::unique_lock<std::mutex> lock(task_mutex_);
if (!utils::isa<PyObjectRef>(output_tensors) || auto task = std::make_shared<CreateCommGroupTask>();
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) { task->group_name_ = group_name;
MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !"; task->ranks_ = ranks;
} ready_tasks_.push(task);
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_; task_cond_var_.notify_all();
py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj); sync_cond_var_.wait(lock);
return tuple_tensors; return task->result_;
}
bool Executor::DestroyCommGroup(const std::string &group_name) {
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<DestroyCommGroupTask>();
task->group_name_ = group_name;
ready_tasks_.push(task);
task_cond_var_.notify_all();
sync_cond_var_.wait(lock);
return task->result_;
} }
void Executor::StopWorker() { void Executor::StopWorker() {
......
...@@ -32,10 +32,22 @@ ...@@ -32,10 +32,22 @@
#include "ir/tensor.h" #include "ir/tensor.h"
#include "utils/any.h" #include "utils/any.h"
#include "utils/contract.h" #include "utils/contract.h"
#include "utils/comm_manager.h"
namespace mindspore { namespace mindspore {
namespace session { namespace session {
enum TaskType { kUnKnown, kExit, kCompileNodes, kCompileGraph, kBuildGraph, kBuildOp, kRunGraph, kRunOp }; enum TaskType {
kUnKnown,
kExit,
kCompileNodes,
kCompileGraph,
kBuildGraph,
kBuildOp,
kRunGraph,
kRunOp,
kCreateCommGroup,
kDestroyCommGroup
};
class Task { class Task {
public: public:
...@@ -106,6 +118,25 @@ class RunOpTask : public Task { ...@@ -106,6 +118,25 @@ class RunOpTask : public Task {
VectorRef outputs_; VectorRef outputs_;
}; };
// Task queued to the executor worker thread to create a communication group
// with the given name over the given ranks (see CreateCommGroupTask::Run).
class CreateCommGroupTask : public Task {
 public:
  CreateCommGroupTask() { type_ = kCreateCommGroup; }
  ~CreateCommGroupTask() override = default;
  void Run() override;
  std::string group_name_;       // name of the group to create
  std::vector<uint32_t> ranks_;  // ranks that form the group
  // Outcome written by Run(); initialized to false so a caller that reads it
  // before the task ran never sees an indeterminate value.
  bool result_{false};
};
// Task queued to the executor worker thread to destroy the named
// communication group (see DestroyCommGroupTask::Run).
class DestroyCommGroupTask : public Task {
 public:
  DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
  ~DestroyCommGroupTask() override = default;
  void Run() override;
  std::string group_name_;  // name of the group to destroy
  // Outcome written by Run(); initialized to false so it is never read as an
  // indeterminate value.
  bool result_{false};
};
class ExitTask : public Task { class ExitTask : public Task {
public: public:
ExitTask() { type_ = kExit; } ExitTask() { type_ = kExit; }
...@@ -125,9 +156,11 @@ class Executor { ...@@ -125,9 +156,11 @@ class Executor {
VectorRef *outputs); VectorRef *outputs);
void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask); const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask);
py::tuple RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, void RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors); const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
void OnRunGraphFinished(); void OnRunGraphFinished();
bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
bool DestroyCommGroup(const std::string &group_name);
private: private:
void UpdateOutputTensors(VectorRef *outputs, void UpdateOutputTensors(VectorRef *outputs,
...@@ -143,11 +176,7 @@ class Executor { ...@@ -143,11 +176,7 @@ class Executor {
std::mutex task_mutex_; std::mutex task_mutex_;
std::mutex pending_task_mutex_; std::mutex pending_task_mutex_;
std::condition_variable task_cond_var_; std::condition_variable task_cond_var_;
std::condition_variable compile_cond_var_; std::condition_variable sync_cond_var_;
std::condition_variable build_cond_var_;
std::condition_variable run_cond_var_;
std::condition_variable build_op_cond_var_;
std::condition_variable run_op_cond_var_;
std::queue<std::shared_ptr<Task>> ready_tasks_; std::queue<std::shared_ptr<Task>> ready_tasks_;
std::list<std::shared_ptr<RunGraphTask>> pending_tasks_; std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
std::shared_ptr<std::thread> worker_; std::shared_ptr<std::thread> worker_;
......
...@@ -1344,10 +1344,10 @@ void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_i ...@@ -1344,10 +1344,10 @@ void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_i
executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask); executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
} }
py::tuple SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info, void SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) { const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(executor_); MS_EXCEPTION_IF_NULL(executor_);
return executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors); executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
} }
void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
......
...@@ -90,7 +90,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { ...@@ -90,7 +90,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int> &tensors_mask); const std::vector<int> &tensors_mask);
py::tuple RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors); void RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs);
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
*/ */
#include "frontend/parallel/group_manager.h" #include "frontend/parallel/group_manager.h"
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "frontend/parallel/device_manager.h" #include "frontend/parallel/device_manager.h"
#include "backend/session/executor_manager.h"
#include "utils/comm_manager.h" #include "utils/comm_manager.h"
#include "utils/ms_context.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
...@@ -96,8 +96,14 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto ...@@ -96,8 +96,14 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
vector<uint32_t> ranks; vector<uint32_t> ranks;
(void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks), (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
[](const Device dev) { return (uint32_t)dev.rank(); }); [](const Device dev) { return (uint32_t)dev.rank(); });
// Create group through the CommManager interface // Create group through the executor
bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
MS_EXCEPTION_IF_NULL(executor);
bool ret = executor->CreateCommGroup(group_name, ranks);
if (!ret) { if (!ret) {
MS_LOG(ERROR) << "Create group failed, group name is " << group_name; MS_LOG(ERROR) << "Create group failed, group name is " << group_name;
return Status::FAILED; return Status::FAILED;
...@@ -108,6 +114,20 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto ...@@ -108,6 +114,20 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
} }
} }
// Destroys the communication group |group_name| through the executor bound to
// the current device target/id, so destruction is serialized with the other
// tasks on that executor's worker thread.
// Returns Status::SUCCESS on success, Status::FAILED otherwise.
Status GroupManager::DestroyGroup(const std::string &group_name) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
  MS_EXCEPTION_IF_NULL(executor);
  bool ret = executor->DestroyCommGroup(group_name);
  if (!ret) {
    // Log on failure for symmetry with CreateGroup's error reporting.
    MS_LOG(ERROR) << "Destroy group failed, group name is " << group_name;
    return Status::FAILED;
  }
  return Status::SUCCESS;
}
Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) { Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
std::string name = (*group).name(); std::string name = (*group).name();
auto it = groups_.find(name); auto it = groups_.find(name);
...@@ -116,18 +136,14 @@ Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) { ...@@ -116,18 +136,14 @@ Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
return Status::FAILED; return Status::FAILED;
} }
(void)groups_.erase(it); (void)groups_.erase(it);
bool ret = CommManager::GetInstance().DestroyGroup(name); return DestroyGroup(name);
if (!ret) {
return Status::FAILED;
}
return Status::SUCCESS;
} }
Status GroupManager::DestroyAllGroups() { Status GroupManager::DestroyAllGroups() {
for (auto &it : groups_) { for (auto &it : groups_) {
std::string name = it.first; std::string name = it.first;
bool ret = CommManager::GetInstance().DestroyGroup(name); auto ret = DestroyGroup(name);
if (!ret) { if (ret != Status::SUCCESS) {
return Status::FAILED; return Status::FAILED;
} }
} }
......
...@@ -65,6 +65,7 @@ class GroupManager { ...@@ -65,6 +65,7 @@ class GroupManager {
void Clear(); void Clear();
private: private:
Status DestroyGroup(const std::string &group_name);
// the key is group name (name_) // the key is group name (name_)
std::map<std::string, Group> groups_; std::map<std::string, Group> groups_;
std::string world_group_; std::string world_group_;
......
...@@ -19,18 +19,22 @@ ...@@ -19,18 +19,22 @@
#include <typeinfo> #include <typeinfo>
#include <map> #include <map>
#include <set> #include <set>
#include <memory>
#include <unordered_set> #include <unordered_set>
#include <algorithm> #include <algorithm>
#include "debug/trace.h" #include "debug/trace.h"
#include "pybind_api/ir/tensor_py.h" #include "pybind_api/ir/tensor_py.h"
#include "ir/param_info.h" #include "ir/param_info.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/any.h" #include "utils/any.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "utils/context/context_extends.h" #include "utils/context/context_extends.h"
#include "utils/config_manager.h" #include "utils/config_manager.h"
#include "utils/convert_utils_py.h" #include "utils/convert_utils_py.h"
#include "utils/base_ref_extends.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h" #include "frontend/operator/composite/composite.h"
#include "frontend/operator/composite/do_signature.h" #include "frontend/operator/composite/do_signature.h"
...@@ -554,6 +558,32 @@ void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tens ...@@ -554,6 +558,32 @@ void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tens
*input_tensors = new_input_tensors; *input_tensors = new_input_tensors;
} }
// Recursively converts a BaseRef produced by graph execution into a python
// value: a VectorRef becomes a py::tuple (returned wrapped in a PyObjectRef);
// a bare tensor is passed through unchanged.  Any other BaseRef kind raises
// an exception.
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<VectorRef>(base_ref)) {
    auto ref_list = utils::cast<VectorRef>(base_ref);
    py::tuple output_tensors(ref_list.size());
    for (size_t i = 0; i < ref_list.size(); ++i) {
      // Depth-first: each element is either a tensor or a nested list that
      // was already converted into a py::tuple carried inside a PyObjectRef.
      auto output = TransformBaseRefListToTuple(ref_list[i]);
      if (utils::isa<tensor::TensorPtr>(output)) {
        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        output_tensors[i] = tensor_ptr;
      } else if (utils::isa<PyObjectRef>(output)) {
        py::object obj = utils::cast<PyObjectRef>(output).object_;
        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
        output_tensors[i] = tensor_tuple;
      } else {
        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
      }
    }
    // Explicitly wrap the tuple so callers can detect the tuple case via
    // utils::isa<PyObjectRef>.
    return std::make_shared<PyObjectRef>(output_tensors);
  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  } else {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
}
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(op_exec_info);
MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
...@@ -577,7 +607,19 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat ...@@ -577,7 +607,19 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask); session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, &input_tensors); EraseValueNodeTensor(tensors_mask, &input_tensors);
py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors);
VectorRef outputs;
session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors, &outputs);
// Trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(outputs);
if (!utils::isa<PyObjectRef>(output_tensors) ||
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
}
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
py::tuple result = py::cast<py::tuple>(tuple_obj);
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
*status = PYNATIVE_SUCCESS; *status = PYNATIVE_SUCCESS;
MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms"; MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册