Commit 4f08d35c authored by mindspore-ci-bot, committed by Gitee

!1415 support mix target

Merge pull request !1415 from kisnwang/support-mix-target
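
This merge adds mixed-target execution: a single func graph may now contain operators pinned to different device targets, e.g. a few CPU kernels inside an otherwise Ascend graph. A per-primitive "target" attribute determines each CNode's target; the VM compiler reorders and splits the graph into single-target segments, the backend keeps a dedicated CPU session next to the configured target session and routes every segment accordingly, and sink optimizations (multi-graph sink, loop sink) are disabled whenever a graph mixes targets.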
@@ -35,6 +35,7 @@ class AscendDeviceAddress : public DeviceAddress {
   ~AscendDeviceAddress() override;
   bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override;
   bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
+  DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; }
 #ifdef ENABLE_DUMP_E2E
   bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt,
                      const std::vector<int> &host_shape, TypeId host_type) const;
......
@@ -259,6 +259,15 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
   return true;
 }
 
+bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
+  if (AnfAlgo::OutputAddrExist(kernel, index)) {
+    auto address = AnfAlgo::GetOutputAddr(kernel, index);
+    MS_EXCEPTION_IF_NULL(address);
+    return address->DeviceType() == DeviceAddressType::kAscend;
+  }
+  return false;
+}
+
 DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                           TypeId type_id) {
   return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
......
@@ -45,6 +45,7 @@ class AscendKernelRuntime : public KernelRuntime {
  protected:
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                        TypeId type_id) override;
+  bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
   bool SyncStream() override;
 
  private:
......
@@ -34,6 +34,7 @@ class CPUDeviceAddress : public DeviceAddress {
   bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override;
   bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
+  DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; }
 };
 }  // namespace cpu
 }  // namespace device
......
@@ -48,6 +48,7 @@ class GPUMemoryManager;
 namespace mindspore {
 namespace device {
 enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice };
+enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };
 
 class DeviceAddress {
  public:
@@ -64,6 +65,7 @@ class DeviceAddress {
   TypeId type_id() const { return type_id_; }
   virtual void set_status(DeviceAddressStatus status) {}
   virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
+  virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
 
  protected:
   const void *ptr() const { return ptr_; }
......
@@ -35,6 +35,7 @@ class GPUDeviceAddress : public DeviceAddress {
   bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
   void set_status(DeviceAddressStatus status) { status_ = status; }
   DeviceAddressStatus status() const { return status_; }
+  DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; }
 
  private:
   DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
......
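
The hunks above give every DeviceAddress a backend tag via a virtual DeviceType() query. A minimal standalone sketch of this dispatch pattern (simplified names, not the actual MindSpore classes):

// sketch: virtual DeviceType() tagging, as introduced by the hunks above
#include <iostream>
#include <memory>

enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };

class DeviceAddress {
 public:
  virtual ~DeviceAddress() = default;
  // Base class answers kUnknown; each backend overrides with its own tag.
  virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
};

class CPUDeviceAddress : public DeviceAddress {
 public:
  DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; }
};

int main() {
  std::unique_ptr<DeviceAddress> addr = std::make_unique<CPUDeviceAddress>();
  // A runtime can now ask an address which backend owns it:
  std::cout << (addr->DeviceType() == DeviceAddressType::kCPU) << "\n";  // prints 1
  return 0;
}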
@@ -102,6 +102,13 @@ bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
   return false;
 }
 
+bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
+  if (AnfAlgo::OutputAddrExist(kernel, index)) {
+    return true;
+  }
+  return false;
+}
+
 size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
   MS_EXCEPTION_IF_NULL(node);
   if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
@@ -255,7 +262,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
     if (i < graph_valid_input.size() && !graph_valid_input[i]) {
       continue;
     }
-    if (AnfAlgo::OutputAddrExist(item, 0)) {
+    if (NodeOutputDeviceAddressExist(item, 0)) {
       continue;
     }
     auto output_size = AnfAlgo::GetOutputTensorNum(item);
@@ -431,7 +438,7 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
     if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
       continue;
     }
-    if (AnfAlgo::OutputAddrExist(node, i)) {
+    if (NodeOutputDeviceAddressExist(node, i)) {
       MS_LOG(INFO) << "Already malloc index:" << i;
       continue;
     }
@@ -493,7 +500,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(ms_context);
   for (auto &value_node : graph->graph_value_nodes()) {
     MS_EXCEPTION_IF_NULL(value_node);
-    if (AnfAlgo::OutputAddrExist(value_node, 0)) {
+    if (NodeOutputDeviceAddressExist(value_node, 0)) {
       MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist";
       continue;
     }
......
@@ -67,6 +67,7 @@ class KernelRuntime {
  protected:
   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                TypeId type_id) = 0;
+  virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
   virtual bool SyncStream() = 0;
   void AssignStaticMemory(session::KernelGraph *graph);
   void AssignDynamicMemory(session::KernelGraph *graph);
......
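
Together these hunks split the existence check in two: the generic KernelRuntime::NodeOutputDeviceAddressExist only asks whether any output address exists, while the Ascend override additionally requires that the address actually lives on Ascend, so an output produced by a CPU segment still gets Ascend memory assigned. A condensed, hypothetical sketch of that refinement (not the real classes):

// sketch: base existence check vs. target-aware override
enum class DeviceAddressType { kUnknown, kAscend, kCPU };

struct Address {
  DeviceAddressType type;
};

struct KernelRuntime {
  virtual ~KernelRuntime() = default;
  // Generic rule: any existing address counts.
  virtual bool NodeOutputDeviceAddressExist(const Address *addr) const { return addr != nullptr; }
};

struct AscendKernelRuntime : KernelRuntime {
  // Ascend rule: the address must also live on Ascend; a CPU-owned address
  // is treated as missing so Ascend memory gets (re)assigned for it.
  bool NodeOutputDeviceAddressExist(const Address *addr) const override {
    return addr != nullptr && addr->type == DeviceAddressType::kAscend;
  }
};

int main() {
  Address cpu_addr{DeviceAddressType::kCPU};
  KernelRuntime generic_rt;
  AscendKernelRuntime ascend_rt;
  // Generic runtime: address exists; Ascend runtime: treated as missing.
  return (generic_rt.NodeOutputDeviceAddressExist(&cpu_addr) &&
          !ascend_rt.NodeOutputDeviceAddressExist(&cpu_addr)) ? 0 : 1;
}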
@@ -307,17 +307,27 @@ bool TaskEmitAction(const ResourcePtr &res) {
   }
   FuncGraphPtr func_graph = res->func_graph();
   auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
 
   if (IsCtrlSink()) {
     res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
     return true;
   }
 
   std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
   if (bc_ptr->name() == kMsConvert) {
     cut_list = compile::GetMsNonlinearOps();
   }
 
   std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (compile->ContainMixedTarget(func_graph)) {
+    bc_ptr->set_is_multi_graph_sink(false);
+    context_ptr->set_loop_sink_flag(false);
+  } else if (context_ptr->execution_mode() != kPynativeMode) {
+    std::string device_target = context_ptr->device_target();
+    if (device_target == kAscendDevice) {
+      bc_ptr->set_is_multi_graph_sink(true);
+    }
+  }
   res->results()[kOutput] = compile->CompileAndLink(func_graph);
   return true;
 }
......
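
The new branch in TaskEmitAction decides the sink mode before compilation: a mixed-target graph can never be sunk to the device as one unit, since execution must return to the host between segments. A hypothetical condensation of that decision as a pure function (the real code mutates the backend and MsContext instead):

// sketch: sink-mode decision for mixed-target graphs (assumed helper, not real API)
#include <iostream>
#include <string>

struct SinkFlags {
  bool multi_graph_sink;
  bool loop_sink;
};

SinkFlags DecideSinkFlags(bool mixed_target, bool pynative_mode, const std::string &device_target,
                          SinkFlags current) {
  if (mixed_target) {
    // Segments alternate between devices, so neither whole-graph sink applies.
    return {false, false};
  }
  if (!pynative_mode && device_target == "Ascend") {
    current.multi_graph_sink = true;  // pure Ascend graph mode keeps multi-graph sink
  }
  return current;
}

int main() {
  SinkFlags flags = DecideSinkFlags(/*mixed_target=*/true, /*pynative_mode=*/false, "Ascend", {true, true});
  std::cout << flags.multi_graph_sink << " " << flags.loop_sink << "\n";  // prints 0 0
  return 0;
}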
@@ -778,7 +778,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
   MS_EXCEPTION_IF_NULL(convert_fn);
   // Convert CNodeList to LinConvertResult.
   ConfigManager::GetInstance().set_iter_num(1);
-  auto runner = convert_fn({app_init});
+  auto runner = convert_fn({app_init}, "");
   if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
     backend->Link(runner.graph_id);
   }
......
@@ -28,6 +28,23 @@
 namespace mindspore {
 namespace session {
+ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(anf);
+  if (!anf->isa<Parameter>()) {
+    MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
+  }
+  auto valid_inputs = graph->MutableValidInputs();
+  MS_EXCEPTION_IF_NULL(valid_inputs);
+  auto graph_inputs = graph->MutableInputs();
+  MS_EXCEPTION_IF_NULL(graph_inputs);
+  TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
+  ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
+  TraceManager::EndTrace();
+  graph_inputs->push_back(new_parameter);
+  valid_inputs->push_back(valid_input);
+  return new_parameter;
+}
+
 GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   auto graph_id = graph_sum_;
   auto graph = ConstructKernelGraph(lst, outputs);
......
@@ -35,6 +35,9 @@ class CPUSession : public SessionBasic {
   GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
 
+ protected:
+  ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override;
+
  private:
   void SetKernelInfo(const KernelGraph *kernel_graph);
   void BuildKernel(const KernelGraph *kernel_graph);
......
@@ -482,7 +482,13 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
     depend_nodes = GetOutputNodes(depend_node);
   }
   for (auto &first_node : prior_nodes) {
+    if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
+      continue;
+    }
     for (auto &second_node : depend_nodes) {
+      if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
+        continue;
+      }
       MS_EXCEPTION_IF_NULL(first_node);
       MS_EXCEPTION_IF_NULL(second_node);
       MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
......
@@ -311,7 +311,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
   if (python_paras_ == nullptr) {
     python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
   }
-  if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) == kInvalidGraphId) {
+  if (python_paras_->find(m_tensor) != python_paras_->end()) {
     new_parameter = (*python_paras_)[m_tensor];
   } else {
     TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
......
@@ -114,7 +114,7 @@ class SessionBasic {
   BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
   // create a new kernel graph and update the graph sum
   KernelGraphPtr NewKernelGraph();
-  ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
+  virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
   ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
   ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
   AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
......
@@ -92,7 +92,7 @@ class MsContext {
   bool ir_fusion_flag() const { return ir_fusion_flag_; }
 
   bool loop_sink_flag() const { return enable_loop_sink_; }
+  void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; }
   void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; }
   bool enable_mem_reuse() const { return enable_mem_reuse_; }
......
@@ -39,14 +39,14 @@ LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
   multi_result_.inputs = g->parameters();
   final_output_ = NewValueNode("fake_output");
   multi_result_.outputs = {final_output_};
-  GraphId final_g = sess_->GetFinalRunGraph();
+  GraphId final_g = target_sess_->GetFinalRunGraph();
   multi_result_.run = std::make_shared<RunFunc>(
-    [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args); });
+    [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); });
   return multi_result_;
 }
 
-LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
+LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
   MS_LOG(DEBUG) << "MsConvert";
   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
   auto cached = g_ConvertCache.find(lst);
@@ -64,17 +64,24 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
   result.inputs = inputs;
   result.outputs = outputs;
   result.graph_id = kInvalidGraphId;
-  auto graph_id = sess_->CompileGraph(lst, outputs);
-  if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
-    sess_->BuildGraph(graph_id);
+  GraphId graph_id = kInvalidGraphId;
+  if (target == kCPUDevice) {
+    graph_id = cpu_sess_->CompileGraph(lst, outputs);
+  } else {
+    graph_id = target_sess_->CompileGraph(lst, outputs);
   }
   if (MsContext::GetInstance()->precompile_only()) {
     MS_LOG(INFO) << "PrecompileOnly, stop run graph";
     return result;
   }
+  if (target == kCPUDevice) {
+    cpu_sess_->BuildGraph(graph_id);
+  } else if (!is_multi_graph_sink_) {
+    target_sess_->BuildGraph(graph_id);
+  }
   result.run = std::make_shared<RunFunc>(
-    [graph_id, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args); });
+    [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
   MS_EXCEPTION_IF_NULL(result.run);
 
   result.simu_run = std::make_shared<RunFunc>(
@@ -92,7 +99,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
   GraphId cond_g = kInvalidGraphId;
   if (utils::isa<AnfNodePtr>(c)) {
-    cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
+    cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
   } else {
     MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
   }
@@ -116,7 +123,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
     MS_LOG(DEBUG) << "invoke set active:" << active_g;
   }
   MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
-  sess_->SetActive(active_g, cond_g);
+  target_sess_->SetActive(active_g, cond_g);
 }
 
 void MsBackend::SetSwitchGraph() {
@@ -135,12 +142,12 @@ void MsBackend::SetSwitchGraph() {
     }
     GraphId cond_g = kInvalidGraphId;
     if (utils::isa<AnfNodePtr>(curr_switch_)) {
-      cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
+      cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
     } else {
       MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
     }
     MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
-    sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
+    target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
   }
   is_switch_call_ = false;
   MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
@@ -202,7 +209,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef
         old_args[i] = args[it->second];
       }
     }
-    sess_->SetChildGraphInput(graph, old_args);
+    target_sess_->SetChildGraphInput(graph, old_args);
   }
   graph_inputs_.erase(c);
 }
@@ -211,7 +218,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef
 VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
   MS_LOG(DEBUG) << "set graph input:" << g;
   // switch maybe twice
-  sess_->SetChildGraphInput(g, args);
+  target_sess_->SetChildGraphInput(g, args);
 
   if (is_switch_call_) {
     if (!curr_switch_.is_null()) {
@@ -236,7 +243,7 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
   return VectorRef(outputs);
 }
 
-VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
+VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
   MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
   // Run graph
   std::vector<tensor::TensorPtr> inputs;
@@ -271,7 +278,12 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
   VectorRef outputs;
   // call ms rungraph (graphId, input ,output)
-  sess_->RunGraph(g, inputs, &outputs);
+  if (target == kCPUDevice) {
+    cpu_sess_->RunGraph(g, inputs, &outputs);
+  } else {
+    target_sess_->RunGraph(g, inputs, &outputs);
+  }
+
   MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
   return outputs;
 }
@@ -300,17 +312,17 @@ void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
   (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
                        [](const AnfNodePtr &v) { return v; });
   MS_LOG(DEBUG) << "Simulate start";
-  (void)sess_->SetFinalGraphInput(parameters);
+  (void)target_sess_->SetFinalGraphInput(parameters);
   BaseRef output = rt->Eval(VectorRef(args));
-  sess_->SetFinalGraphOutput(output);
+  target_sess_->SetFinalGraphOutput(output);
   MS_LOG(DEBUG) << "Simulate Eval end";
 }
 
 void MsBackend::Link(GraphId graph_id) {
   if (graph_id == kInvalidGraphId) {
-    graph_id = sess_->GetFinalRunGraph();
+    graph_id = target_sess_->GetFinalRunGraph();
   }
-  sess_->BuildGraph(graph_id);
+  target_sess_->BuildGraph(graph_id);
 }
 
 Backend::Backend(const std::string &name) : name_(name) {
@@ -322,16 +334,26 @@ Backend::Backend(const std::string &name) : name_(name) {
 }
 
 MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
-  convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1);
-  sess_ = session::SessionFactory::Get().Create(target);
-  if (sess_ == nullptr) {
+  convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
+  target_sess_ = session::SessionFactory::Get().Create(target);
+  if (target_sess_ == nullptr) {
     MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
   }
-  sess_->Init(device_id);
-  sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
+  target_sess_->Init(device_id);
+  target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
+  if (target == kCPUDevice) {
+    cpu_sess_ = target_sess_;
+  } else {
+    cpu_sess_ = session::SessionFactory::Get().Create(kCPUDevice);
+    if (cpu_sess_ == nullptr) {
+      MS_LOG(EXCEPTION) << "Create cpu session failed with target " << target << ".";
+    }
+    cpu_sess_->Init(0);
+    cpu_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
+  }
 }
 
-GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); }
+GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); }
 
 VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
......
@@ -91,8 +91,8 @@ class MsBackend : public Backend {
   MsBackend(const std::string &name, const std::string &target, uint32_t device_id);
   ~MsBackend() override = default;
 
-  LinConvertResult MsConvert(const AnfNodePtrList &lst);
-  VectorRef MsRunGraph(const GraphId &g, const VectorRef &args);
+  LinConvertResult MsConvert(const AnfNodePtrList &lst, const std::string &target = "");
+  VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
 
   VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args);
   void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override;
@@ -109,7 +109,8 @@ class MsBackend : public Backend {
   VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
 
  private:
-  session::SessionPtr sess_;
+  session::SessionPtr target_sess_;
+  session::SessionPtr cpu_sess_;
   std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_;
   std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
   std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_;
......
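
MsBackend now owns two sessions: target_sess_ for the configured device and cpu_sess_ for CPU segments (aliased to target_sess_ when the target already is CPU), with the per-segment target string selecting between them at compile, build, and run time. A simplified, hypothetical sketch of the routing (not the real session API):

// sketch: two-session routing in the backend
#include <iostream>
#include <memory>
#include <string>

struct Session {
  explicit Session(std::string device) : device_(std::move(device)) {}
  void RunGraph(int graph_id) const { std::cout << "graph " << graph_id << " on " << device_ << "\n"; }
  std::string device_;
};

struct MsBackendSketch {
  explicit MsBackendSketch(const std::string &target)
      : target_sess_(std::make_shared<Session>(target)),
        // Reuse the target session when it already is CPU, as the constructor above does.
        cpu_sess_(target == "CPU" ? target_sess_ : std::make_shared<Session>("CPU")) {}

  void MsRunGraph(int g, const std::string &target) const {
    // CPU segments go to the CPU session; everything else to the target session.
    (target == "CPU" ? cpu_sess_ : target_sess_)->RunGraph(g);
  }

  std::shared_ptr<Session> target_sess_;
  std::shared_ptr<Session> cpu_sess_;
};

int main() {
  MsBackendSketch backend("Ascend");
  backend.MsRunGraph(1, "CPU");  // graph 1 on CPU
  backend.MsRunGraph(2, "");     // graph 2 on Ascend
  return 0;
}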
@@ -148,7 +148,7 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
 // This implementation will convert the nodes into a subgraph
 // that will run using the MsVM.
 template <typename T>
-LinConvertResult Convert(const AnfNodePtrList &lst) {
+LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) {
   auto cached = g_ConvertCache.find(lst);
   if (cached != g_ConvertCache.end()) {
     return cached->second;
......
@@ -43,7 +43,7 @@ struct LinConvertResult {
   uint32_t graph_id;
 };
 
-using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &)>;
+using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &, const std::string &)>;
 using ConvertCache = std::unordered_map<BaseRef, LinConvertResult, BaseRefHash>;
 extern LinkFuncType MsVmConvert;
 extern LinkFuncType GeVmConvert;
......
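
Widening LinkFuncType means every converter now receives the segment's target alongside the node list, and MsBackend's constructor binds the extra placeholder accordingly. A standalone sketch of how the target parameter threads through the callback (stand-in types, not the real IR):

// sketch: threading the per-segment target through the converter callback
#include <functional>
#include <iostream>
#include <string>
#include <vector>

using AnfNodePtrList = std::vector<int>;  // stand-in for the real node list
using LinkFuncType = std::function<int(const AnfNodePtrList &, const std::string &)>;

struct Backend {
  int MsConvert(const AnfNodePtrList &lst, const std::string &target) {
    std::cout << "convert " << lst.size() << " nodes for target '" << target << "'\n";
    return 0;
  }
};

int main() {
  Backend backend;
  // Same shape as the constructor change: bind both the node list and the target.
  LinkFuncType convert_fn =
      std::bind(&Backend::MsConvert, &backend, std::placeholders::_1, std::placeholders::_2);
  convert_fn({1, 2, 3}, "CPU");  // a segment pinned to CPU
  convert_fn({4, 5}, "");        // empty target falls back to the default session
  return 0;
}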
@@ -20,6 +20,8 @@
 #include <algorithm>
 #include <map>
+#include <queue>
+#include <set>
 #include <string>
 #include <vector>
@@ -47,6 +49,86 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
   return ms_nonlinear_ops;
 }
+namespace {
+std::string GetCNodeTarget(const AnfNodePtr &node) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string default_target = context_ptr->device_target();
+  if (!node->isa<CNode>()) {
+    return default_target;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto attr_input = cnode->input(kAnfPrimitiveIndex);
+  if (attr_input == nullptr) {
+    return default_target;
+  }
+  auto value_node = attr_input->cast<ValueNodePtr>();
+  if (value_node == nullptr) {
+    return default_target;
+  }
+  auto value = value_node->value();
+  if (value == nullptr) {
+    return default_target;
+  }
+  if (!value->isa<Primitive>()) {
+    return default_target;
+  }
+  auto primitive = value->cast<PrimitivePtr>();
+  ValuePtr att_target = primitive->GetAttr("target");
+  if (att_target != nullptr) {
+    std::string target = GetValue<std::string>(att_target);
+    return target;
+  }
+  return default_target;
+}
+
+bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string last_target = context_ptr->device_target();
+  for (auto &node : nodes) {
+    if (node->isa<CNode>()) {
+      std::string cur_target = GetCNodeTarget(node);
+      if (last_target != cur_target) {
+        return true;
+      }
+      last_target = cur_target;
+    }
+  }
+  return false;
+}
+
+void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
+  std::queue<AnfNodePtr> queue;
+  queue.push(graph->get_return());
+  std::set<AnfNodePtr> visited;
+  while (!queue.empty()) {
+    auto &node = queue.front();
+    queue.pop();
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    for (auto &input : cnode->inputs()) {
+      auto iter = nodes_ref->find(input);
+      if (iter != nodes_ref->end()) {
+        iter->second++;
+      } else {
+        (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1));
+      }
+      if (visited.find(input) != visited.end()) {
+        continue;
+      }
+      visited.insert(input);
+      queue.push(input);
+    }
+  }
+}
+}  // namespace
+
 CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
     : backend_(backend), cut_list_(cut_list) {
   MS_EXCEPTION_IF_NULL(backend_);
@@ -98,12 +180,67 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
   return false;
 }
 
+std::vector<AnfNodePtr> CompileGraph::SplitSort(const FuncGraphPtr &graph) {
+  std::vector<AnfNodePtr> result;
+  std::queue<AnfNodePtr> queue;
+  std::queue<AnfNodePtr> next_queue;
+  std::map<AnfNodePtr, size_t> nodes_ref;
+  CalcNodeRefCount(graph, &nodes_ref);
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string queue_target = context_ptr->device_target();
+  std::string next_target = "";
+  queue.push(graph->get_return());
+  while (!queue.empty() || !next_queue.empty()) {
+    if (queue.empty()) {
+      queue.swap(next_queue);
+      queue_target = next_target;
+    }
+    auto &node = queue.front();
+    queue.pop();
+    MS_EXCEPTION_IF_NULL(node);
+    result.emplace_back(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    for (auto &input : cnode->inputs()) {
+      auto iter = nodes_ref.find(input);
+      if (iter != nodes_ref.end()) {
+        iter->second--;
+        if (iter->second != 0) {
+          continue;
+        }
+      }
+      if (!input->isa<CNode>()) {
+        queue.push(input);
+        continue;
+      }
+      std::string input_target = GetCNodeTarget(input);
+      if (input_target == queue_target) {
+        queue.push(input);
+      } else if (next_queue.empty() || input_target == next_target) {
+        next_queue.push(input);
+        next_target = input_target;
+      } else {
+        MS_LOG(EXCEPTION) << "only support two different target";
+      }
+    }
+  }
+  std::reverse(result.begin(), result.end());
+  return result;
+}
+
 VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   VectorRef splits;
   VectorRef split;
-  std::vector<AnfNodePtr> nodes = TopoSort(graph->get_return());
+  auto nodes = TopoSort(graph->get_return());
+  if (ContainMultiTarget(nodes)) {
+    nodes = SplitSort(graph);
+  }
+  std::string last_target;
 
   MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
   for (auto &node : nodes) {
     MS_EXCEPTION_IF_NULL(node);
@@ -114,7 +251,13 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
       }
       splits.push_back(node);
       split.clear();
-    } else if (!(node->isa<ValueNode>() || node->isa<Parameter>())) {
+    } else if (node->isa<CNode>()) {
+      std::string cur_target = GetCNodeTarget(node);
+      if (cur_target != last_target && !last_target.empty() && split.size() != 0) {
+        splits.push_back(split);
+        split.clear();
+      }
+      last_target = cur_target;
       split.push_back(node);
       MS_LOG(DEBUG) << "Insert node:" << node->DebugString(10) << ", size:" << split.size();
     }
@@ -200,14 +343,14 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
   }
 }
 
-int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) {
+int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, const std::string &target) {
   MS_LOG(DEBUG) << "LinConvert start";
   LinConvertResult result;
 
   if (backend_->simu_flag()) {
     result = backend_->GetMultiGraphRun(graph);
   } else {
-    result = lin_convert_(node_list);
+    result = lin_convert_(node_list, target);
   }
 
   if (result.run == nullptr) {
@@ -316,7 +459,12 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
       auto vec_ref = utils::cast<VectorRef>(split);
       (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args),
                            [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); });
-      ret = LinConvert(graph, args);
+      if (args.size() > 0) {
+        std::string cur_target = GetCNodeTarget(args[0]);
+        ret = LinConvert(graph, args, cur_target);
+      } else {
+        ret = LinConvert(graph, args);
+      }
       MS_LOG(DEBUG) << "End a extern LinConvert";
       if (ret == RET_FAILED) {
         return false;
@@ -637,6 +785,19 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
   return rt;
 }
 
+bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) {
+  auto graph_manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(graph_manager);
+  FuncGraphSet graphs = graph_manager->func_graphs();
+  for (auto &g : graphs) {
+    auto nodes = TopoSort(g->get_return());
+    if (ContainMultiTarget(nodes)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 BackendPtr CreateBackend() {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
......
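
SplitSort orders nodes so that runs with the same target stay contiguous (at most two distinct targets are supported), and SplitNodes then closes a segment whenever the target changes between consecutive CNodes. A runnable sketch of that segmentation step with stand-in types (not the real IR):

// sketch: grouping a target-stable node order into single-target segments,
// mirroring the new logic in SplitNodes
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::string target;
};

std::vector<std::vector<Node>> SplitByTarget(const std::vector<Node> &nodes) {
  std::vector<std::vector<Node>> splits;
  std::vector<Node> split;
  std::string last_target;
  for (const auto &node : nodes) {
    // Close the current segment at every target switch, as SplitNodes does.
    if (node.target != last_target && !last_target.empty() && !split.empty()) {
      splits.push_back(split);
      split.clear();
    }
    last_target = node.target;
    split.push_back(node);
  }
  if (!split.empty()) {
    splits.push_back(split);
  }
  return splits;
}

int main() {
  std::vector<Node> nodes = {{"a", "Ascend"}, {"b", "Ascend"}, {"c", "CPU"}, {"d", "Ascend"}};
  auto segments = SplitByTarget(nodes);
  std::cout << segments.size() << " segments\n";  // 3 segments: {a,b}, {c}, {d}
  return 0;
}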
@@ -79,8 +79,9 @@ class CompileGraph {
  private:
   void PushParameters(const FuncGraphPtr &func_graph);
+  std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph);
   bool SplitGraph(const FuncGraphPtr &func_graph);
-  int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list);
+  int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
   int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
   int AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
   void AddSinkSwitch(const CNodePtr &node);
@@ -124,6 +125,7 @@ class CompileGraphs {
   void Compile(const FuncGraphPtr &func_graph);
   FinalVMPtr Link(const FuncGraphPtr &func_graph);
   FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
+  bool ContainMixedTarget(const FuncGraphPtr &graph);
 
  private:
   InstSet insts_;
......
@@ -65,7 +65,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
   for (auto &item : utils::cast<VectorRef>(todos[0])) {
     anf_list.push_back(utils::cast<AnfNodePtr>(item));
   }
-  auto convertResult = MsVmConvert(anf_list);
+  auto convertResult = MsVmConvert(anf_list, "");
   auto runResult = (*(convertResult.run))(args);
   ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0);
 }
@@ -89,7 +89,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
   for (auto &item : utils::cast<VectorRef>(todos[0])) {
     anf_list.push_back(utils::cast<AnfNodePtr>(item));
   }
-  auto convertResult = MsVmConvert(anf_list);
+  auto convertResult = MsVmConvert(anf_list, "");
   auto runResult = (*(convertResult.run))(args);
   ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0);
 }
@@ -113,7 +113,7 @@ TEST_F(TestCompileSegmentRunner, test_if) {
   for (auto &item : utils::cast<VectorRef>(todos[0])) {
     anf_list.push_back(utils::cast<AnfNodePtr>(item));
   }
-  auto convertResult = MsVmConvert(anf_list);
+  auto convertResult = MsVmConvert(anf_list, "");
   auto runResult = (*(convertResult.run))(args);
   auto result = py::cast<bool>(BaseRefToPyData(runResult[0]));
......