Commit 4f08d35c authored by mindspore-ci-bot, committed by Gitee

!1415 support mix target

Merge pull request !1415 from kisnwang/support-mix-target
......@@ -35,6 +35,7 @@ class AscendDeviceAddress : public DeviceAddress {
~AscendDeviceAddress() override;
bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override;
bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; }
#ifdef ENABLE_DUMP_E2E
bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt,
const std::vector<int> &host_shape, TypeId host_type) const;
......
......@@ -259,6 +259,15 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
return true;
}
bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
if (AnfAlgo::OutputAddrExist(kernel, index)) {
auto address = AnfAlgo::GetOutputAddr(kernel, index);
MS_EXCEPTION_IF_NULL(address);
return address->DeviceType() == DeviceAddressType::kAscend;
}
return false;
}
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) {
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
......
......@@ -45,6 +45,7 @@ class AscendKernelRuntime : public KernelRuntime {
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
bool SyncStream() override;
private:
......
......@@ -34,6 +34,7 @@ class CPUDeviceAddress : public DeviceAddress {
bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override;
bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; }
};
} // namespace cpu
} // namespace device
......
......@@ -48,6 +48,7 @@ class GPUMemoryManager;
namespace mindspore {
namespace device {
enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice };
enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };
class DeviceAddress {
public:
......@@ -64,6 +65,7 @@ class DeviceAddress {
TypeId type_id() const { return type_id_; }
virtual void set_status(DeviceAddressStatus status) {}
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
protected:
const void *ptr() const { return ptr_; }
......
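The hunks above carry the core mechanism of this PR: DeviceAddress gains a virtual DeviceType() that the concrete Ascend/CPU/GPU address classes override, with kUnknown as the base default. A minimal standalone sketch of the pattern, using simplified stand-ins rather than the actual MindSpore classes:

#include <iostream>
#include <memory>

enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };

struct DeviceAddress {
  virtual ~DeviceAddress() = default;
  virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
};
struct CPUDeviceAddress : DeviceAddress {
  DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; }
};

int main() {
  std::unique_ptr<DeviceAddress> addr = std::make_unique<CPUDeviceAddress>();
  // An address created by the CPU session reports kCPU, not kAscend.
  std::cout << (addr->DeviceType() == DeviceAddressType::kAscend) << "\n";  // prints 0
}

This is exactly the check AscendKernelRuntime::NodeOutputDeviceAddressExist performs in the hunk above: an output address whose DeviceType() is not kAscend counts as absent, so the Ascend runtime allocates its own device memory instead of reusing CPU memory.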
......@@ -35,6 +35,7 @@ class GPUDeviceAddress : public DeviceAddress {
bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
void set_status(DeviceAddressStatus status) { status_ = status; }
DeviceAddressStatus status() const { return status_; }
DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; }
private:
DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
......
......@@ -102,6 +102,13 @@ bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
return false;
}
bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
if (AnfAlgo::OutputAddrExist(kernel, index)) {
return true;
}
return false;
}
size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
MS_EXCEPTION_IF_NULL(node);
if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
......@@ -255,7 +262,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
if (i < graph_valid_input.size() && !graph_valid_input[i]) {
continue;
}
if (AnfAlgo::OutputAddrExist(item, 0)) {
if (NodeOutputDeviceAddressExist(item, 0)) {
continue;
}
auto output_size = AnfAlgo::GetOutputTensorNum(item);
......@@ -431,7 +438,7 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
continue;
}
if (AnfAlgo::OutputAddrExist(node, i)) {
if (NodeOutputDeviceAddressExist(node, i)) {
MS_LOG(INFO) << "Already malloc index:" << i;
continue;
}
......@@ -493,7 +500,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(ms_context);
for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
if (AnfAlgo::OutputAddrExist(value_node, 0)) {
if (NodeOutputDeviceAddressExist(value_node, 0)) {
MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist";
continue;
}
......
......@@ -67,6 +67,7 @@ class KernelRuntime {
protected:
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) = 0;
virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
virtual bool SyncStream() = 0;
void AssignStaticMemory(session::KernelGraph *graph);
void AssignDynamicMemory(session::KernelGraph *graph);
......
......@@ -307,17 +307,27 @@ bool TaskEmitAction(const ResourcePtr &res) {
}
FuncGraphPtr func_graph = res->func_graph();
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
if (IsCtrlSink()) {
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
return true;
}
std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
if (bc_ptr->name() == kMsConvert) {
cut_list = compile::GetMsNonlinearOps();
}
std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (compile->ContainMixedTarget(func_graph)) {
bc_ptr->set_is_multi_graph_sink(false);
context_ptr->set_loop_sink_flag(false);
} else if (context_ptr->execution_mode() != kPynativeMode) {
std::string device_target = context_ptr->device_target();
if (device_target == kAscendDevice) {
bc_ptr->set_is_multi_graph_sink(true);
}
}
res->results()[kOutput] = compile->CompileAndLink(func_graph);
return true;
}
......
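The TaskEmitAction change wires the new ContainMixedTarget check into compilation: a graph that mixes targets cannot be sunk to one device as a whole, so both multi-graph sink and loop sink are switched off; otherwise graph-mode Ascend keeps multi-graph sink on. A standalone sketch of just this decision, where the flag structs are simplified stand-ins for the real BackendPtr and MsContext:

#include <iostream>
#include <string>

struct BackendFlags { bool multi_graph_sink = false; };
struct ContextFlags { bool loop_sink = true; };

void ApplyMixedTargetPolicy(bool mixed_target, bool pynative_mode,
                            const std::string &device_target,
                            BackendFlags *bc, ContextFlags *ctx) {
  if (mixed_target) {
    bc->multi_graph_sink = false;  // segments must return control to the VM
    ctx->loop_sink = false;        // loop sinking also assumes a single device
  } else if (!pynative_mode && device_target == "Ascend") {
    bc->multi_graph_sink = true;   // whole graph can sink to the device
  }
}

int main() {
  BackendFlags bc;
  ContextFlags ctx;
  ApplyMixedTargetPolicy(true, false, "Ascend", &bc, &ctx);
  std::cout << bc.multi_graph_sink << " " << ctx.loop_sink << "\n";  // 0 0
}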
......@@ -778,7 +778,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
MS_EXCEPTION_IF_NULL(convert_fn);
// Convert CNodeList to LinConvertResult.
ConfigManager::GetInstance().set_iter_num(1);
auto runner = convert_fn({app_init});
auto runner = convert_fn({app_init}, "");
if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
backend->Link(runner.graph_id);
}
......
......@@ -28,6 +28,23 @@
namespace mindspore {
namespace session {
ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
TraceManager::EndTrace();
graph_inputs->push_back(new_parameter);
valid_inputs->push_back(valid_input);
return new_parameter;
}
GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
auto graph_id = graph_sum_;
auto graph = ConstructKernelGraph(lst, outputs);
......
......@@ -35,6 +35,9 @@ class CPUSession : public SessionBasic {
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
protected:
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override;
private:
void SetKernelInfo(const KernelGraph *kernel_graph);
void BuildKernel(const KernelGraph *kernel_graph);
......
......@@ -482,7 +482,13 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
depend_nodes = GetOutputNodes(depend_node);
}
for (auto &first_node : prior_nodes) {
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
continue;
}
for (auto &second_node : depend_nodes) {
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
continue;
}
MS_EXCEPTION_IF_NULL(first_node);
MS_EXCEPTION_IF_NULL(second_node);
MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
......
......@@ -311,7 +311,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
if (python_paras_ == nullptr) {
python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
}
if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) == kInvalidGraphId) {
if (python_paras_->find(m_tensor) != python_paras_->end()) {
new_parameter = (*python_paras_)[m_tensor];
} else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
......
......@@ -114,7 +114,7 @@ class SessionBasic {
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
// create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph();
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
......
......@@ -92,7 +92,7 @@ class MsContext {
bool ir_fusion_flag() const { return ir_fusion_flag_; }
bool loop_sink_flag() const { return enable_loop_sink_; }
void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; }
void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; }
bool enable_mem_reuse() const { return enable_mem_reuse_; }
......
......@@ -39,14 +39,14 @@ LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
multi_result_.inputs = g->parameters();
final_output_ = NewValueNode("fake_output");
multi_result_.outputs = {final_output_};
GraphId final_g = sess_->GetFinalRunGraph();
GraphId final_g = target_sess_->GetFinalRunGraph();
multi_result_.run = std::make_shared<RunFunc>(
[final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args); });
[final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); });
return multi_result_;
}
LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
MS_LOG(DEBUG) << "MsConvert";
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
auto cached = g_ConvertCache.find(lst);
......@@ -64,17 +64,24 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
result.inputs = inputs;
result.outputs = outputs;
result.graph_id = kInvalidGraphId;
auto graph_id = sess_->CompileGraph(lst, outputs);
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
sess_->BuildGraph(graph_id);
GraphId graph_id = kInvalidGraphId;
if (target == kCPUDevice) {
graph_id = cpu_sess_->CompileGraph(lst, outputs);
} else {
graph_id = target_sess_->CompileGraph(lst, outputs);
}
if (MsContext::GetInstance()->precompile_only()) {
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return result;
}
if (target == kCPUDevice) {
cpu_sess_->BuildGraph(graph_id);
} else if (!is_multi_graph_sink_) {
target_sess_->BuildGraph(graph_id);
}
result.run = std::make_shared<RunFunc>(
[graph_id, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args); });
[graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
MS_EXCEPTION_IF_NULL(result.run);
result.simu_run = std::make_shared<RunFunc>(
......@@ -92,7 +99,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
GraphId cond_g = kInvalidGraphId;
if (utils::isa<AnfNodePtr>(c)) {
cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
} else {
MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
}
......@@ -116,7 +123,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
MS_LOG(DEBUG) << "invoke set active:" << active_g;
}
MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
sess_->SetActive(active_g, cond_g);
target_sess_->SetActive(active_g, cond_g);
}
void MsBackend::SetSwitchGraph() {
......@@ -135,12 +142,12 @@ void MsBackend::SetSwitchGraph() {
}
GraphId cond_g = kInvalidGraphId;
if (utils::isa<AnfNodePtr>(curr_switch_)) {
cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
} else {
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
}
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
}
is_switch_call_ = false;
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
......@@ -202,7 +209,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef
old_args[i] = args[it->second];
}
}
sess_->SetChildGraphInput(graph, old_args);
target_sess_->SetChildGraphInput(graph, old_args);
}
graph_inputs_.erase(c);
}
......@@ -211,7 +218,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
MS_LOG(DEBUG) << "set graph input:" << g;
// switch maybe twice
sess_->SetChildGraphInput(g, args);
target_sess_->SetChildGraphInput(g, args);
if (is_switch_call_) {
if (!curr_switch_.is_null()) {
......@@ -236,7 +243,7 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
return VectorRef(outputs);
}
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
// Run graph
std::vector<tensor::TensorPtr> inputs;
......@@ -271,7 +278,12 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
VectorRef outputs;
// call ms rungraph (graphId, input ,output)
sess_->RunGraph(g, inputs, &outputs);
if (target == kCPUDevice) {
cpu_sess_->RunGraph(g, inputs, &outputs);
} else {
target_sess_->RunGraph(g, inputs, &outputs);
}
MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
return outputs;
}
......@@ -300,17 +312,17 @@ void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
[](const AnfNodePtr &v) { return v; });
MS_LOG(DEBUG) << "Simulate start";
(void)sess_->SetFinalGraphInput(parameters);
(void)target_sess_->SetFinalGraphInput(parameters);
BaseRef output = rt->Eval(VectorRef(args));
sess_->SetFinalGraphOutput(output);
target_sess_->SetFinalGraphOutput(output);
MS_LOG(DEBUG) << "Simulate Eval end";
}
void MsBackend::Link(GraphId graph_id) {
if (graph_id == kInvalidGraphId) {
graph_id = sess_->GetFinalRunGraph();
graph_id = target_sess_->GetFinalRunGraph();
}
sess_->BuildGraph(graph_id);
target_sess_->BuildGraph(graph_id);
}
Backend::Backend(const std::string &name) : name_(name) {
......@@ -322,16 +334,26 @@ Backend::Backend(const std::string &name) : name_(name) {
}
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1);
sess_ = session::SessionFactory::Get().Create(target);
if (sess_ == nullptr) {
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
target_sess_ = session::SessionFactory::Get().Create(target);
if (target_sess_ == nullptr) {
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
}
sess_->Init(device_id);
sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
target_sess_->Init(device_id);
target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
if (target == kCPUDevice) {
cpu_sess_ = target_sess_;
} else {
cpu_sess_ = session::SessionFactory::Get().Create(kCPUDevice);
if (cpu_sess_ == nullptr) {
MS_LOG(EXCEPTION) << "Create cpu session failed with target " << target << ".";
}
cpu_sess_->Init(0);
cpu_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
}
}
GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); }
GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); }
VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
......
......@@ -91,8 +91,8 @@ class MsBackend : public Backend {
MsBackend(const std::string &name, const std::string &target, uint32_t device_id);
~MsBackend() override = default;
LinConvertResult MsConvert(const AnfNodePtrList &lst);
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args);
LinConvertResult MsConvert(const AnfNodePtrList &lst, const std::string &target = "");
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args);
void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override;
......@@ -109,7 +109,8 @@ class MsBackend : public Backend {
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
private:
session::SessionPtr sess_;
session::SessionPtr target_sess_;
session::SessionPtr cpu_sess_;
std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_;
std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_;
......
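With sess_ renamed to target_sess_ and the new cpu_sess_ member, MsBackend now always keeps a CPU session alongside the primary one (they are the same object when the configured target is already CPU), and MsConvert/MsRunGraph route each segment by its target string. A runnable sketch of that routing rule, with a hypothetical SelectSession helper and a simplified Session type:

#include <iostream>
#include <memory>
#include <string>

struct Session { std::string device; };
using SessionPtr = std::shared_ptr<Session>;

// Hypothetical helper mirroring the branches in MsConvert/MsRunGraph:
// CPU-targeted segments use the dedicated CPU session; everything else
// (including an empty target) uses the primary target session.
SessionPtr SelectSession(const SessionPtr &target_sess, const SessionPtr &cpu_sess,
                         const std::string &target) {
  return target == "CPU" ? cpu_sess : target_sess;
}

int main() {
  auto ascend = std::make_shared<Session>(Session{"Ascend"});
  auto cpu = std::make_shared<Session>(Session{"CPU"});
  std::cout << SelectSession(ascend, cpu, "CPU")->device << "\n";  // CPU
  std::cout << SelectSession(ascend, cpu, "")->device << "\n";     // Ascend
}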
......@@ -148,7 +148,7 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
// This implementation will convert the nodes into a subgraph
// that will run using the MsVM.
template <typename T>
LinConvertResult Convert(const AnfNodePtrList &lst) {
LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) {
auto cached = g_ConvertCache.find(lst);
if (cached != g_ConvertCache.end()) {
return cached->second;
......
......@@ -43,7 +43,7 @@ struct LinConvertResult {
uint32_t graph_id;
};
using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &)>;
using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &, const std::string &)>;
using ConvertCache = std::unordered_map<BaseRef, LinConvertResult, BaseRefHash>;
extern LinkFuncType MsVmConvert;
extern LinkFuncType GeVmConvert;
......
......@@ -20,6 +20,8 @@
#include <algorithm>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>
......@@ -47,6 +49,86 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
return ms_nonlinear_ops;
}
namespace {
std::string GetCNodeTarget(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
if (!node->isa<CNode>()) {
return default_target;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto attr_input = cnode->input(kAnfPrimitiveIndex);
if (attr_input == nullptr) {
return default_target;
}
auto value_node = attr_input->cast<ValueNodePtr>();
if (value_node == nullptr) {
return default_target;
}
auto value = value_node->value();
if (value == nullptr) {
return default_target;
}
if (!value->isa<Primitive>()) {
return default_target;
}
auto primitive = value->cast<PrimitivePtr>();
ValuePtr att_target = primitive->GetAttr("target");
if (att_target != nullptr) {
std::string target = GetValue<std::string>(att_target);
return target;
}
return default_target;
}
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string last_target = context_ptr->device_target();
for (auto &node : nodes) {
if (node->isa<CNode>()) {
std::string cur_target = GetCNodeTarget(node);
if (last_target != cur_target) {
return true;
}
last_target = cur_target;
}
}
return false;
}
void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
std::queue<AnfNodePtr> queue;
queue.push(graph->get_return());
std::set<AnfNodePtr> visited;
while (!queue.empty()) {
auto &node = queue.front();
queue.pop();
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (auto &input : cnode->inputs()) {
auto iter = nodes_ref->find(input);
if (iter != nodes_ref->end()) {
iter->second++;
} else {
(void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1));
}
if (visited.find(input) != visited.end()) {
continue;
}
visited.insert(input);
queue.push(input);
}
}
}
} // namespace
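GetCNodeTarget above resolves a node's placement in two steps: a string attribute named "target" on the node's primitive wins; otherwise the context-wide device_target applies. A standalone model of that resolution order (a plain map stands in for the real Primitive attribute store):

#include <iostream>
#include <map>
#include <string>

std::string ResolveTarget(const std::map<std::string, std::string> &prim_attrs,
                          const std::string &context_device_target) {
  // Per-primitive "target" attribute takes precedence over the global setting.
  auto it = prim_attrs.find("target");
  return it != prim_attrs.end() ? it->second : context_device_target;
}

int main() {
  std::map<std::string, std::string> tagged = {{"target", "CPU"}};
  std::map<std::string, std::string> untagged;
  std::cout << ResolveTarget(tagged, "Ascend") << "\n";    // CPU
  std::cout << ResolveTarget(untagged, "Ascend") << "\n";  // Ascend
}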
CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
: backend_(backend), cut_list_(cut_list) {
MS_EXCEPTION_IF_NULL(backend_);
......@@ -98,12 +180,67 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
return false;
}
std::vector<AnfNodePtr> CompileGraph::SplitSort(const FuncGraphPtr &graph) {
std::vector<AnfNodePtr> result;
std::queue<AnfNodePtr> queue;
std::queue<AnfNodePtr> next_queue;
std::map<AnfNodePtr, size_t> nodes_ref;
CalcNodeRefCount(graph, &nodes_ref);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string queue_target = context_ptr->device_target();
std::string next_target = "";
queue.push(graph->get_return());
while (!queue.empty() || !next_queue.empty()) {
if (queue.empty()) {
queue.swap(next_queue);
queue_target = next_target;
}
auto &node = queue.front();
queue.pop();
MS_EXCEPTION_IF_NULL(node);
result.emplace_back(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (auto &input : cnode->inputs()) {
auto iter = nodes_ref.find(input);
if (iter != nodes_ref.end()) {
iter->second--;
if (iter->second != 0) {
continue;
}
}
if (!input->isa<CNode>()) {
queue.push(input);
continue;
}
std::string input_target = GetCNodeTarget(input);
if (input_target == queue_target) {
queue.push(input);
} else if (next_queue.empty() || input_target == next_target) {
next_queue.push(input);
next_target = input_target;
} else {
MS_LOG(EXCEPTION) << "only support two different target";
}
}
}
std::reverse(result.begin(), result.end());
return result;
}
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
VectorRef splits;
VectorRef split;
std::vector<AnfNodePtr> nodes = TopoSort(graph->get_return());
auto nodes = TopoSort(graph->get_return());
if (ContainMultiTarget(nodes)) {
nodes = SplitSort(graph);
}
std::string last_target;
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
......@@ -114,7 +251,13 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
}
splits.push_back(node);
split.clear();
} else if (!(node->isa<ValueNode>() || node->isa<Parameter>())) {
} else if (node->isa<CNode>()) {
std::string cur_target = GetCNodeTarget(node);
if (cur_target != last_target && !last_target.empty() && split.size() != 0) {
splits.push_back(split);
split.clear();
}
last_target = cur_target;
split.push_back(node);
MS_LOG(DEBUG) << "Insert node:" << node->DebugString(10) << ", size:" << split.size();
}
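Taken together, SplitSort and the new target check in SplitNodes turn a mixed-target graph into maximal single-target segments: SplitSort reorders the topological sort so that same-target nodes stay contiguous as far as dependencies allow (raising an exception if more than two distinct targets alternate), and SplitNodes then closes a segment whenever the target changes; SplitGraph later hands each segment to LinConvert with the target of its first node. A standalone sketch of the grouping step:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // (node, target) pairs in already-sorted order, as SplitSort would emit them.
  std::vector<std::pair<std::string, std::string>> nodes = {
      {"Conv2D", "Ascend"}, {"ReLU", "Ascend"}, {"SparseApply", "CPU"}, {"MatMul", "Ascend"}};
  std::vector<std::vector<std::string>> splits;
  std::string last_target;
  for (const auto &node : nodes) {
    // A target change (or the very first node) opens a new segment.
    if (splits.empty() || node.second != last_target) {
      splits.emplace_back();
    }
    last_target = node.second;
    splits.back().push_back(node.first);
  }
  std::cout << splits.size() << " segments\n";  // 3: {Conv2D, ReLU}, {SparseApply}, {MatMul}
}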
......@@ -200,14 +343,14 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
}
}
int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) {
int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, const std::string &target) {
MS_LOG(DEBUG) << "LinConvert start";
LinConvertResult result;
if (backend_->simu_flag()) {
result = backend_->GetMultiGraphRun(graph);
} else {
result = lin_convert_(node_list);
result = lin_convert_(node_list, target);
}
if (result.run == nullptr) {
......@@ -316,7 +459,12 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
auto vec_ref = utils::cast<VectorRef>(split);
(void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args),
[](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); });
ret = LinConvert(graph, args);
if (args.size() > 0) {
std::string cur_target = GetCNodeTarget(args[0]);
ret = LinConvert(graph, args, cur_target);
} else {
ret = LinConvert(graph, args);
}
MS_LOG(DEBUG) << "End a extern LinConvert";
if (ret == RET_FAILED) {
return false;
......@@ -637,6 +785,19 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
return rt;
}
bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) {
auto graph_manager = graph->manager();
MS_EXCEPTION_IF_NULL(graph_manager);
FuncGraphSet graphs = graph_manager->func_graphs();
for (auto &g : graphs) {
auto nodes = TopoSort(g->get_return());
if (ContainMultiTarget(nodes)) {
return true;
}
}
return false;
}
BackendPtr CreateBackend() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
......
......@@ -79,8 +79,9 @@ class CompileGraph {
private:
void PushParameters(const FuncGraphPtr &func_graph);
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph);
bool SplitGraph(const FuncGraphPtr &func_graph);
int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list);
int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
int AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
void AddSinkSwitch(const CNodePtr &node);
......@@ -124,6 +125,7 @@ class CompileGraphs {
void Compile(const FuncGraphPtr &func_graph);
FinalVMPtr Link(const FuncGraphPtr &func_graph);
FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
bool ContainMixedTarget(const FuncGraphPtr &graph);
private:
InstSet insts_;
......
......@@ -65,7 +65,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
auto convertResult = MsVmConvert(anf_list);
auto convertResult = MsVmConvert(anf_list, "");
auto runResult = (*(convertResult.run))(args);
ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0);
}
......@@ -89,7 +89,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
auto convertResult = MsVmConvert(anf_list);
auto convertResult = MsVmConvert(anf_list, "");
auto runResult = (*(convertResult.run))(args);
ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0);
}
......@@ -113,7 +113,7 @@ TEST_F(TestCompileSegmentRunner, test_if) {
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
auto convertResult = MsVmConvert(anf_list);
auto convertResult = MsVmConvert(anf_list, "");
auto runResult = (*(convertResult.run))(args);
auto result = py::cast<bool>(BaseRefToPyData(runResult[0]));
......