提交 c20cd122 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3110 Delete deprecated codes of ascend control flow

Merge pull request !3110 from zhoufeng/delete-deprecated-codes
......@@ -51,26 +51,16 @@ class AscendSession : public SessionBasic {
py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) override;
// set parameters of final graph
GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &args) override;
// set output of final graph
void SetFinalGraphOutput(const BaseRef &output) override;
// insert switch and set the relative active ops
void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override;
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
void SetChildGraphInput(GraphId g, const VectorRef &args) override;
// get graph id in child graphs by ME front anf node pointer
GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
// get graph id of final graph
GraphId GetFinalRunGraph() const override { return final_graph_id_; }
// insert active to graph
void SetActive(GraphId, GraphId) override;
// compile child graph when session have multiple child graphs
void CompileChildGraph(const KernelGraphPtr &child_graph);
void RecurseGetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
void GetSummaryNodes(KernelGraph *graph);
private:
void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
void SetSummaryNodes(KernelGraph *graph) override;
void InitRuntimeResource();
void SelectKernel(const KernelGraph &kernel_graph) const;
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
......@@ -92,63 +82,21 @@ class AscendSession : public SessionBasic {
void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index);
size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
void SetFinalGraphOutput(const AnfNodePtr &node);
void SetFinalGraphOutput(const ValuePtr &value);
void SetFinalGraphOutput(const VectorRef &vec_output);
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
const NotNull<std::set<KernelGraphPtr> *> memo);
// split graphs with recurse from root graph
void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
static void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
const std::vector<CNodePtr> &list);
void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
AnfNodePtr BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, const std::vector<CNodePtr> &child_graph_list);
// merge execution order list of child graphs
void MergeGraphExecOrder();
// insert assign op to sync data between different graphs
void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
// insert multiple assigns to graph
void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
// insert active op to graph
void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream);
// get execute index of graph
size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
// handle condition graph from vm
void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
// insert depend to graph, used to attach control nodes to graph
void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
// insert depend to graph, used to attach control nodes to graph
void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
// set child graph parameter if front arg is a anf
void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
// set child graph parameter if front arg is a tensor
void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
// update the execution order of all child graphs
void UpdateGraphOrder(GraphId to_graph);
// handle switch when merge
void MergeSwitchCompile();
// get graph order vector by graph id
std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id);
const std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id) const;
// get graph order type vector by graph id
std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
// copy output of if and else
void CopyOutputOfIf(GraphId false_graph_id);
const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
// check if graph cache exist
bool GraphCacheExist(const GraphInfo &graph_info) const;
// insert all assign to child graph
void InsertAllAssigns();
// create fake output of final graph
AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
// sync initial tensors' data to device
void SyncInitialTenosrToDevice();
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
......@@ -162,16 +110,10 @@ class AscendSession : public SessionBasic {
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
// member variables
// key is final_graph_id,value is child graph execute order of final graph
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
// key is final_graph_id,value is the graph types of child graphs
std::unordered_map<GraphId, std::vector<GraphType>> graph_order_types_;
// record condition graph of while
std::unordered_map<GraphId, GraphId> while_condition_graphs_;
// record all conditions
std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
std::unordered_map<GraphId, AnfNodePtr> condition_output_;
// share parameters
std::vector<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
// initial tensors, these tensor will sync data to device before run graph
......
......@@ -108,7 +108,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
kernel_graph->set_execution_order(execution_order);
NamedSummaryOutputs summary_outputs;
if (enable_summary) {
GetSummaryNodes(kernel_graph.get());
SetSummaryNodes(kernel_graph.get());
summary_outputs = kernel_graph->summary_nodes();
runtime_.IncreaseSummaryRefCount(summary_outputs);
}
......
......@@ -217,7 +217,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
Reorder(&execution_order);
graph->set_execution_order(execution_order);
// Get summary nodes.
GetSummaryNodes(graph.get());
SetSummaryNodes(graph.get());
// Remove NoOp from execution graph
opt::RemoveNopNode(graph.get());
// Set graph manager.
......
......@@ -898,27 +898,6 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
std::queue<AnfNodePtr> seed_nodes;
UpdateNodeEdgeList(&seed_nodes);
}
// update graph inputs in child graph
auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
[&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
return n.first == old_anf_node.get();
});
if (it_real_inputs != real_inputs_.end()) {
// erase old parameter in map
auto old_args = it_real_inputs->second;
real_inputs_.erase(it_real_inputs);
// insert new parameter to map
auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(),
[&new_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
return n.first == new_anf_node.get();
});
if (iter != real_inputs_.end()) {
MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited.";
iter->second = old_args;
} else {
real_inputs_.emplace_back(new_anf_node, old_args);
}
}
}
void KernelGraph::UpdateExecuteKernelStreamLabel() {
......@@ -953,56 +932,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
return result;
}
// Record `arg` as one real (actual) input bound to the formal `parameter`.
// Multiple args may accumulate under the same parameter; insertion order is kept
// because real_inputs_ is a vector of (parameter, args) pairs, not a map.
void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(arg);
MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
// NOTE(review): the two checks below duplicate the ones above — harmless but redundant.
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(arg);
// Linear search for an existing entry for this parameter.
auto iter = std::find_if(
real_inputs_.begin(), real_inputs_.end(),
[&parameter](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { return n.first == parameter; });
if (iter != real_inputs_.end()) {
// Parameter already recorded: append the new arg to its list.
auto &args = iter->second;
args.push_back(arg);
} else {
// First arg for this parameter: create a new (parameter -> {arg}) entry.
real_inputs_.emplace_back(parameter, std::vector<AnfNodePtr>(1, arg));
}
}
// Mark `arg` (coming from `from_graph`) as not reusable; the entry is later
// consulted/rewritten by UpdateCallRealInput when call nodes are resolved.
void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph) {
unreuse_args_[arg] = from_graph;
}
// Rewrite real_inputs_: every real input that is a call node is replaced by the
// outputs of the called child graph (via GetCallRealOutputs), and the
// unreuse_args_ bookkeeping is carried over to those replacement nodes.
void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "Update graph id: " << graph_id_;
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map;
for (auto &it : real_inputs_) {
auto parameter = it.first;
MS_EXCEPTION_IF_NULL(parameter);
auto real_inputs = it.second;
std::vector<AnfNodePtr> new_real_inputs;
for (auto &real_input : real_inputs) {
// If the real input is a call node, the child graph outputs act as the new real inputs.
auto tmp_real_input = GetCallRealOutputs(real_input);
std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs));
// Replace the call entry in unreuse_args_ with its resolved outputs.
auto unreuse_arg_it = unreuse_args_.find(real_input);
if (unreuse_arg_it != unreuse_args_.end()) {
auto old_graph = unreuse_arg_it->second;
// NOTE(review): this iterates over ALL outputs accumulated so far for this
// parameter, not only the ones produced by the current call — confirm intended.
for (auto new_real_input : new_real_inputs) {
// A call referencing a graph-output parameter is allowed to be reused,
// so parameters are deliberately skipped here.
if (!new_real_input->isa<Parameter>()) {
unreuse_args_[new_real_input] = old_graph;
}
}
}
}
real_inputs_map.emplace_back(parameter, new_real_inputs);
}
real_inputs_ = real_inputs_map;
}
void KernelGraph::PrintGraphExecuteOrder() const {
MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
for (size_t i = 0; i < execution_order_.size(); i++) {
......
......@@ -131,16 +131,8 @@ class KernelGraph : public FuncGraph {
void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
// find anf node in graph
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
// get real inputs
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; }
void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
// mark unreused args
void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph);
const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; }
// used to dump ir
std::string ToString() const override;
// update the real input if the node is a call
void UpdateCallRealInput();
void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
CNodePtr get_start_label() { return start_label_; }
......@@ -212,9 +204,6 @@ class KernelGraph : public FuncGraph {
// valid inputs
std::vector<bool> valid_inputs_;
// new members for control sink process
// all child grahs refers to partial node
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
// child graph execute order in root graph
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
......@@ -223,9 +212,6 @@ class KernelGraph : public FuncGraph {
// parameter graph
std::shared_ptr<KernelGraph> parent_graph_;
// record real parameters,inputs_ is the formal parameters
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_;
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_;
CNodePtr start_label_;
CNodePtr end_goto_;
......
......@@ -890,7 +890,7 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph);
if (!graph->summary_node_exist()) {
......@@ -930,7 +930,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
if (!exist_summary) {
return;
}
GetSummaryNodes(graph);
SetSummaryNodes(graph);
auto summary_outputs = graph->summary_nodes();
std::map<std::string, tensor::TensorPtr> params_list;
// fetch outputs apply kernel in session & run callback functions
......
......@@ -92,19 +92,9 @@ class SessionBasic {
CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph);
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
// set parameters of final graph
virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
// set output of final graph
virtual void SetFinalGraphOutput(const BaseRef &) {}
// insert switch and set the relative active ops
virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {}
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
virtual void SetChildGraphInput(GraphId, const VectorRef &) {}
// get graph id in child graphs by ME front anf node pointer
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
virtual void SetActive(GraphId, GraphId) {}
virtual void GetSummaryNodes(KernelGraph *graph);
void AssignParamKey(const KernelGraphPtr &kernel_graph);
void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const { return true; }
......@@ -120,6 +110,7 @@ class SessionBasic {
#endif
protected:
virtual void SetSummaryNodes(KernelGraph *graph);
// Get graph by graph id ,if not exist return null ptr
KernelGraphPtr GetGraph(GraphId graph_id) const;
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
......
file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base.cc" "pynative_execute.cc")
file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute.cc")
if (ENABLE_GE)
file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc")
......
......@@ -21,7 +21,6 @@
#include "utils/log_adapter.h"
#include "ir/anf.h"
#include "utils/callbacks.h"
#include "utils/graph_utils.h"
#include "utils/base_ref_extends.h"
#include "backend/session/session_factory.h"
#include "common/utils.h"
......@@ -34,19 +33,6 @@ namespace compile {
bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
// Build the LinConvertResult used when multiple graphs are merged into one:
// the merged graph takes `g`'s parameters as inputs, exposes a single fake
// output node, and its run closure executes the session's final graph.
LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
// Multiple graphs merge into one big graph that has parameters at the
// beginning and only one output.
MS_LOG(DEBUG) << "graph:" << g->ToString() << " parameter size:" << g->parameters().size();
multi_result_.inputs = g->parameters();
// Placeholder output; the real output is wired up during simulation.
final_output_ = NewValueNode("fake_output");
multi_result_.outputs = {final_output_};
GraphId final_g = target_sess_->GetFinalRunGraph();
// The run closure captures the final graph id and dispatches to MsRunGraph.
multi_result_.run = std::make_shared<RunFunc>(
[final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); });
return multi_result_;
}
LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
MS_LOG(DEBUG) << "MsConvert";
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
......@@ -96,149 +82,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
return result;
}
// Activate the branch graph selected by condition `c` with value `cond`,
// then notify the session via SetActive. Handles the transition between two
// different switch conditions (e.g. nested while loops) by re-activating the
// previous condition's false branch, or invalidating it when it would loop
// back onto the current condition graph.
void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
GraphId active_g = simu_cond_map_[c].cond_graph_map[cond];
GraphId cond_g = kInvalidGraphId;
// The condition must be an ANF node so its graph id can be looked up.
if (utils::isa<AnfNodePtr>(c)) {
cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
} else {
MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
}
auto before_cond = curr_switch_;
// Condition changed since the last switch: fix up the previous condition.
if (curr_switch_.hash() != c.hash()) {
// invoke while false->before true call
if (simu_cond_map_[before_cond].cond_graph_map.count(false)) {
active_g = simu_cond_map_[before_cond].cond_graph_map[false];
} else {
active_g = kInvalidGraphId;
}
// Nested-loop case illustrated below: the previous condition's false
// branch would re-enter the current condition graph, so invalidate it.
// while x < y:
// z = y + 1
// while z < c2:
// out = out + 1
// z = z + 1
if (active_g == cond_g) {
active_g = kInvalidGraphId;
simu_cond_map_[before_cond].cond_graph_map[false] = kInvalidGraphId;
}
MS_LOG(DEBUG) << "invoke set active:" << active_g;
}
MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
target_sess_->SetActive(active_g, cond_g);
}
// Once both branches of the current switch are known (i.e. when the false
// branch has just been simulated), compile the switch in the session via
// SwitchCompile. Only runs when a switch call is pending; always clears
// is_switch_call_ afterwards.
void MsBackend::SetSwitchGraph() {
MS_LOG(DEBUG) << "SetSwitchGraph curr_switch:" << curr_switch_.ToString();
if (is_switch_call_) {
GraphId false_g = kInvalidGraphId;
GraphId true_g = kInvalidGraphId;
MS_LOG(DEBUG) << "start SetSwitchGraph";
true_g = simu_cond_map_[curr_switch_].cond_graph_map[true];
bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
// curr_cond == false means the false branch was just recorded, so both
// branches are now available and the switch can be compiled.
if (!curr_cond) {
if (simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) {
// has false branch
false_g = simu_cond_map_[curr_switch_].cond_graph_map[false];
}
GraphId cond_g = kInvalidGraphId;
if (utils::isa<AnfNodePtr>(curr_switch_)) {
cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
} else {
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
}
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
}
is_switch_call_ = false;
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
}
}
// Convert `node` from a formal parameter to the corresponding actual parameter,
// where the actual parameter is the user graph's formal parameter; walking up
// the graph_user_inputs_ chain reaches the top while-graph's parameter on
// recall. Stops early (returning the current node) when a graph has no
// recorded user.
AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
std::unordered_map<AnfNodePtr, size_t> params_index;
auto result = node;
auto graph = result->func_graph();
// Climb from the node's owning graph toward func_graph, mapping the node
// through each (user graph, inputs) link.
while (func_graph != graph) {
auto iter = graph_user_inputs_.find(graph);
if (iter == graph_user_inputs_.end()) {
break;
}
// Index of each formal parameter within the current graph.
params_index.clear();
auto &params = graph->parameters();
for (size_t i = 0; i < params.size(); ++i) {
params_index[params[i]] = i;
}
// Move one level up: the user graph and the actual inputs it passed.
graph = iter->second.first;
auto &inputs = iter->second.second;
result = inputs[params_index[result]];
}
return result;
}
// Remember which graph (`user`) invokes `func_graph` and with which actual
// `inputs`. Only the first registration wins; later calls for the same graph
// are ignored. Consumed by ConvertGraphInput.
void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user,
const AnfNodePtrList &inputs) {
if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) {
return;
}
graph_user_inputs_[func_graph] = {user, inputs};
}
// Re-bind the child-graph inputs recorded under condition `c` (a while loop)
// with the fresh `args` of this invocation: any stored input that maps back to
// a parameter of `func_graph` is replaced by the corresponding runtime arg,
// then pushed to the session. The condition's entry is erased afterwards.
void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) {
std::unordered_map<AnfNodePtr, size_t> params_index;
// Map each formal parameter of func_graph to its position in `args`.
auto &params = func_graph->parameters();
for (size_t i = 0; i < params.size(); ++i) {
params_index[params[i]] = i;
}
// Recall all child graphs registered under this while-condition.
auto &graph_inputs = graph_inputs_[c];
for (auto &iter : graph_inputs) {
auto &graph = iter.first;
auto &old_args = iter.second;
auto &result = graph_id_map_[graph];
auto &inputs = result.inputs;
for (size_t i = 0; i < inputs.size(); ++i) {
// Resolve the stored input up to func_graph's parameter space.
auto input = ConvertGraphInput(func_graph, inputs[i]);
auto it = params_index.find(input);
if (it != params_index.end()) {
old_args[i] = args[it->second];
}
}
target_sess_->SetChildGraphInput(graph, old_args);
}
// This condition's recall is complete; drop its bookkeeping.
graph_inputs_.erase(c);
}
// compile set input output
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
MS_LOG(DEBUG) << "set graph input:" << g;
// switch maybe twice
target_sess_->SetChildGraphInput(g, args);
if (is_switch_call_) {
if (!curr_switch_.is_null()) {
// push this {g, args} to all user while graph_inputs for nest while,
// when current condition recall over delete this cond in graph_inputs.
for (auto &iter : graph_inputs_) {
iter.second.push_back({g, args});
}
if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) {
graph_inputs_[curr_switch_].push_back({g, args});
}
}
bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g;
simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g;
SetSwitchGraph();
}
std::vector<BaseRef> outputs;
(void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
[](const AnfNodePtr &v) { return v; });
......@@ -290,36 +136,6 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
return outputs;
}
// Record that condition `c` is being simulated with branch `value`.
// Returns kCondAlreadyRun if that branch of `c` already has a graph bound
// (so simulation of it can be skipped), otherwise kCondOk.
SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) {
MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size();
CondGraph cond_graph;
cond_graph.curr_cond = value;
// First time this condition is seen: create its entry.
if (simu_cond_map_.find(c) == simu_cond_map_.end()) {
simu_cond_map_[c] = cond_graph;
}
// This branch was already simulated and bound to a graph.
if (simu_cond_map_[c].cond_graph_map.count(value)) {
return kCondAlreadyRun;
}
simu_cond_map_[c].curr_cond = value;
MS_LOG(DEBUG) << "end set cond ";
return kCondOk;
}
// Simulate-execute `root` on the VM to discover the multi-graph structure:
// the root graph's parameters become the final graph's inputs, the VM is
// evaluated with those parameters as symbolic args, and the evaluation result
// becomes the final graph's output.
void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
MS_LOG(DEBUG) << "Simulate run,root:" << root->ToString() << ", " << root->parameters().size();
std::vector<BaseRef> args;
// Pass the parameters themselves as the (symbolic) argument list.
auto parameters = root->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
[](const AnfNodePtr &v) { return v; });
MS_LOG(DEBUG) << "Simulate start";
(void)target_sess_->SetFinalGraphInput(parameters);
BaseRef output = rt->Eval(VectorRef(args));
target_sess_->SetFinalGraphOutput(output);
MS_LOG(DEBUG) << "Simulate Eval end";
}
void MsBackend::Link(GraphId graph_id) {
if (graph_id == kInvalidGraphId) {
graph_id = target_sess_->GetFinalRunGraph();
......@@ -330,9 +146,6 @@ void MsBackend::Link(GraphId graph_id) {
Backend::Backend(const std::string &name) : name_(name) {
MS_LOG(DEBUG) << "select backend:" << name;
convert_fn_ = backends[name_];
is_switch_call_ = false;
is_multi_graph_sink_ = false;
simu_flag_ = false;
}
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
......
......@@ -43,50 +43,19 @@ class Backend {
LinkFuncType convert_fn() { return convert_fn_; }
std::string name() { return name_; }
virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {}
virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; }
virtual bool GetCond(const BaseRef &c, bool *value);
virtual bool GetIndex(const BaseRef &c, int *value);
virtual void SetSwitchGraph() {}
virtual void SetSwitchActive(const BaseRef &, bool) {}
virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {}
virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {}
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
void set_curr_switch(const BaseRef &value) {
curr_switch_ = value;
is_switch_call_ = true;
}
BaseRef curr_switch() { return curr_switch_; }
virtual void Link(GraphId) {}
virtual LinConvertResult GetMultiGraphRun(const FuncGraphPtr &) { return LinConvertResult(); }
virtual void SetDebugger() {}
LinConvertResult multi_result() { return multi_result_; }
void set_multi_result(const LinConvertResult &value) { multi_result_ = value; }
AnfNodePtr final_output() const { return final_output_; }
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }
bool simu_flag() const { return simu_flag_; }
bool is_switch_call() const { return is_switch_call_; }
void set_simu_flag(bool simu) { simu_flag_ = simu; }
virtual void SetDebugger() {}
protected:
std::string name_;
LinkFuncType convert_fn_;
BaseRef curr_switch_; // curr switch node
bool is_multi_graph_sink_;
bool is_switch_call_;
bool simu_flag_;
LinConvertResult multi_result_;
AnfNodePtr final_output_;
std::unordered_map<FuncGraphPtr, std::pair<FuncGraphPtr, AnfNodePtrList>> graph_user_inputs_;
};
// Per-condition simulation state: the branch currently being simulated and
// the graph id bound to each branch value (true/false).
struct CondGraph {
bool curr_cond;
std::unordered_map<bool, GraphId> cond_graph_map;
};
class MsBackend : public Backend {
......@@ -98,16 +67,7 @@ class MsBackend : public Backend {
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args);
void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override;
SwitchCondStatus SetSimuCond(const BaseRef &c, bool value) override;
void SetSwitchGraph() override;
void SetSwitchActive(const BaseRef &c, bool cond) override;
void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) override;
void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) override;
void Link(GraphId) override;
AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &);
LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
void CreateOtherSession(const std::string &target);
......@@ -121,9 +81,7 @@ class MsBackend : public Backend {
session::SessionPtr other_sess_;
std::string target_device_;
std::string other_device_;
std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_;
std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_;
};
} // namespace compile
} // namespace mindspore
......
......@@ -515,11 +515,7 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no
MS_LOG(DEBUG) << "LinConvert start";
LinConvertResult result;
if (backend_->simu_flag()) {
result = backend_->GetMultiGraphRun(graph);
} else {
result = lin_convert_(node_list, target);
}
result = lin_convert_(node_list, target);
if (result.run == nullptr) {
MS_LOG(ERROR) << "LinConvert failed";
......@@ -546,27 +542,6 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no
return RET_SUCCESS;
}
// Emit the extra instruction sequence a Switch node needs in multi-graph-sink
// mode: a placeholder kCall (-1), a kSwitchReturn on the condition, and a
// kSwitch carrying the condition plus both branch slots. No-op when
// multi-graph sink is disabled.
void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString();
if (backend_->is_multi_graph_sink()) {
VectorRef args;
// -1 marks the call target as not yet resolved at compile time.
args.emplace_back(-1);
MS_LOG(DEBUG) << "call::" << height_;
AddInst(Instruction::kCall, args);
args.clear();
// input(1) is the switch condition node.
args.emplace_back(node->input(1));
AddInst(Instruction::kSwitchReturn, args);
args.clear();
// `false` flags this as the sink variant of kSwitch (cf. AddSwitch).
args.emplace_back(false);
args.emplace_back(Ref(node->input(1)));
args.emplace_back(Ref(node->input(2)));
args.emplace_back(Ref(node->input(3)));
AddInst(Instruction::kSwitch, args);
}
}
int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
......@@ -589,7 +564,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
AddPartial(node);
} else if (IsPrimitive(fn, prim::kPrimSwitch)) {
AddSwitch(node);
AddSinkSwitch(node);
} else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
AddSwitchLayer(node);
} else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
......@@ -607,14 +581,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
return RET_SUCCESS;
}
// Generate the run instructions for the merged multi-graph: a LinConvert with
// an empty node list (which routes to the backend's GetMultiGraphRun when
// simu_flag is set) followed by a return. Throws on conversion failure.
void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) {
auto ret = LinConvert(graph, {});
if (ret == RET_FAILED) {
MS_LOG(EXCEPTION) << "MultiGraphRun failed.";
}
// nullptr: AddReturn takes the output from backend_->final_output() in simu mode.
AddReturn(nullptr);
}
bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
MS_LOG(DEBUG) << "Start split graph";
MS_EXCEPTION_IF_NULL(graph);
......@@ -659,11 +625,6 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
return true;
}
// Thin wrapper: the multi-graph sink instruction set is just a normal Run()
// over the graph (Run itself branches on backend_->simu_flag()).
InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) {
InstSet inst = Run(graph);
return inst;
}
InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
......@@ -672,12 +633,8 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
int param_height = height_;
MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
if (backend_->simu_flag()) {
GenMultiGraphsRun(graph);
} else {
if (!SplitGraph(graph)) {
return inst_;
}
if (!SplitGraph(graph)) {
return inst_;
}
AddPadStack(param_height);
......@@ -712,12 +669,6 @@ void CompileGraph::AddPartial(const CNodePtr &node) {
if (!IsValueNode<FuncGraph>(fn)) {
MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
}
if (backend_->is_multi_graph_sink()) {
auto func_graph = GetValueNode<FuncGraphPtr>(fn);
args.emplace_back(func_graph);
AnfNodePtrList outs(inputs.begin() + 2, inputs.end());
backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
}
for (size_t i = 1; i < inputs.size(); i++) {
args.emplace_back(Ref(inputs[i]));
}
......@@ -739,9 +690,6 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
}
VectorRef args;
if (backend_->is_multi_graph_sink()) {
args.emplace_back(true);
}
args.emplace_back(Ref(inputs[1]));
args.emplace_back(Ref(inputs[2]));
args.emplace_back(Ref(inputs[3]));
......@@ -761,11 +709,7 @@ void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
void CompileGraph::AddReturn(const CNodePtr &node) {
VectorRef args;
if (backend_->simu_flag()) {
args.emplace_back(Ref(backend_->final_output()));
} else {
args.emplace_back(Ref(node->input(1)));
}
args.emplace_back(Ref(node->input(1)));
args.emplace_back(height_);
AddInst(Instruction::kReturn, args);
}
......@@ -783,11 +727,6 @@ void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim)
int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
auto inputs = node->inputs();
AnfNodePtr fn = inputs[0];
if (backend_->is_multi_graph_sink() && IsValueNode<FuncGraph>(fn)) {
auto func_graph = GetValueNode<FuncGraphPtr>(fn);
AnfNodePtrList outs(inputs.begin() + 1, inputs.end());
backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
}
(void)Ref(fn);
size_t size = inputs.size();
for (size_t i = size - 1; i > 0; i--) {
......@@ -929,17 +868,6 @@ FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) {
}
FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
if (backend_->is_multi_graph_sink()) {
backend_->set_simu_flag(true);
MS_LOG(DEBUG) << "Start simulate";
backend_->SimulateRun(rt, graph);
MS_LOG(DEBUG) << "Link graphs";
insts_ = transform_->GenMultiGraphsSinkInst(graph);
rt->set_insts(insts_);
backend_->set_simu_flag(false);
MS_LOG(DEBUG) << "End start simulate";
backend_->Link(kInvalidGraphId);
}
MS_LOG(DEBUG) << "End";
return rt;
}
......
......@@ -54,12 +54,10 @@ class CompileGraph {
~CompileGraph() = default;
InstSet Run(const FuncGraphPtr &func_graph);
InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph);
bool IsCut(const AnfNodePtr &node);
void Push(const AnfNodePtr &node);
void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
void Ret(int nargs);
void GenMultiGraphsRun(const FuncGraphPtr &graph);
int Ref(const AnfNodePtr &node);
VectorRef SplitNodes(const FuncGraphPtr &func_graph);
......@@ -84,7 +82,6 @@ class CompileGraph {
int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
int AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
void AddSinkSwitch(const CNodePtr &node);
void AddPadStack(int param_height);
void AddTailCall(const AnfNodePtr &fn, size_t size);
void AddPartial(const CNodePtr &node);
......
......@@ -17,12 +17,9 @@
*/
#include "vm/vm.h"
#include <algorithm>
#include "vm/vmimpl.h"
#include "vm/backend.h"
#include "vm/transform.h"
#include "pipeline/jit/parse/data_converter.h"
#include "utils/base_ref_extends.h"
......@@ -142,33 +139,10 @@ void FinalVM::Popsp() {
}
}
void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); }
// Pop and return the most recent switch-call flag.
// An empty status stack reads as false ("not a switch call").
bool FinalVM::PopStatus() {
  bool was_switch_call = false;
  if (!ret_status_.empty()) {
    was_switch_call = ret_status_.top();
    ret_status_.pop();
  }
  return was_switch_call;
}
void FinalVM::DoJmp(const BaseRef &jmp_orig) {
MS_LOG(DEBUG) << "Start";
BaseRef jmp = jmp_orig;
if (backend_->simu_flag()) {
bool is_switch_call = false;
if (utils::isa<StructSimuSwitch>(jmp)) { // need to inherit from Base
MS_LOG(DEBUG) << "Start jump StructSwitch";
auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp);
jmp = simu_value->fn_;
backend_->set_curr_switch(simu_value->value_);
is_switch_call = true;
}
PushStatus(is_switch_call);
}
if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
MS_LOG(DEBUG) << "Start jump StructPartial";
auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
......@@ -270,13 +244,6 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
return;
}
auto rv = Ref(-1);
if (utils::isa<AnfNodePtr>(rv) || utils::isa<VectorRef>(rv)) {
auto &c = args[0];
cond_out_[c] = rv;
}
Pop(1);
Popsp();
}
......@@ -294,51 +261,12 @@ void FinalVM::InstReturn(const VectorRef &args) {
int height = utils::cast<int>(args[1]);
auto rv = Ref(rpos);
if (backend_->simu_flag()) {
auto c = backend_->curr_switch();
auto status = PopStatus();
if (status) {
auto iter = cond_out_.find(c);
if (iter != cond_out_.end()) {
rv = MergeArgs(rv, iter->second);
cond_out_.erase(iter);
}
}
if (backend_->is_switch_call()) {
backend_->SetSwitchGraph();
}
}
Pop(height);
Push(rv);
Popp();
MS_LOG(DEBUG) << "End";
}
void FinalVM::InstSimuPartial(const VectorRef &args) {
const size_t args_size = 2;
if (args.size() < args_size) {
MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
<< args.size() << ".";
return;
}
auto &node = args[0];
if (!utils::isa<FuncGraphPtr>(node)) {
MS_LOG(ERROR) << "The type of 1st input of node must be FuncGraph";
return;
}
auto fg = utils::cast<FuncGraphPtr>(node);
int fn_ = utils::cast<int>(args[1]);
auto fn = utils::cast<int>(Ref(fn_));
MS_LOG(DEBUG) << "Partial argssize:" << args.size();
std::vector<BaseRef> outs(args.size() - 2);
(void)std::transform(args.begin() + 2, args.end(), outs.begin(),
[&, this](const BaseRef &a) { return Ref(utils::cast<int>(a)); });
Push(std::make_shared<StructPartial>(fn, VectorRef(outs), fg));
}
void FinalVM::InstRealPartial(const VectorRef &args) {
const size_t args_size = 1;
if (args.size() < args_size) {
......@@ -358,91 +286,10 @@ void FinalVM::InstRealPartial(const VectorRef &args) {
// Build a partial-application closure on the VM stack.
// The simulated (multi-graph-sink) variant was removed with the deprecated
// Ascend control-flow code; the rendered diff artifact here dispatched to
// InstRealPartial twice on the non-sink path (double push). Always take the
// real path exactly once.
void FinalVM::InstPartial(const VectorRef &args) {
  MS_LOG(DEBUG) << "Start";
  InstRealPartial(args);
  MS_LOG(DEBUG) << "End";
}
// Simulated switch for multi-graph-sink mode: instead of branching at run
// time, it records the condition with the backend so both branches can later
// be linked into the sunk graph.
// args[0]: runtime bool condition value; args[1]: stack slot of the condition
// node; args[2]/args[3]: stack slots of the true/false branch closures.
void FinalVM::InstSimuSwitch(const VectorRef &args) {
  const size_t args_size = 4;
  if (args.size() != args_size) {
    MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
                  << ".";
    return;
  }
  bool cond = utils::cast<bool>(args[0]);
  int cond_node = utils::cast<int>(args[1]);
  int vtrue = utils::cast<int>(args[2]);
  int vfalse = utils::cast<int>(args[3]);
  MS_LOG(DEBUG) << "Simu switch cond:" << cond;
  BaseRef c = Ref(cond_node);
  bool bool_value = cond;
  // Register the condition with the backend; kCondAlreadyRun means this
  // switch was already simulated once (second pass over the same condition).
  SwitchCondStatus cond_stat = backend_->SetSimuCond(c, bool_value);
  if (cond_stat == kCondAlreadyRun) {
    MS_LOG(DEBUG) << "switch alreay run bool while true jmp";
    BaseRef jmp = Ref(vtrue);
    // Re-feed the true branch's graph inputs so the backend can wire them up.
    if (utils::isa<StructPartial>(jmp)) {
      auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
      backend_->RecallGraphInput(new_jmp->fg_, new_jmp->args_, c);
    }
    // Remember the false branch so MergeJmpArgs can merge its arguments later.
    cond_jmp_[c] = Ref(vfalse);
    Push(static_cast<int>(cond_stat));
    Popp();
    backend_->SetSwitchActive(c, bool_value);
    return;
  }
  // First encounter: wrap the taken branch in a StructSimuSwitch so DoJmp can
  // recognize it. Only the true path pushes a return point (Pushsp); the false
  // path first merges any previously recorded arguments for this condition.
  if (bool_value) {
    Push(std::make_shared<StructSimuSwitch>(Ref(vtrue), c));
    Pushsp();
  } else {
    MergeJmpArgs(Ref(vfalse), c);
    Push(std::make_shared<StructSimuSwitch>(Ref(vfalse), c));
  }
}
void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
auto iter = cond_jmp_.find(c);
if (iter == cond_jmp_.end()) {
return;
}
auto old_jmp = utils::cast<std::shared_ptr<StructPartial>>(iter->second);
auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
auto &old_args = old_jmp->args_;
auto &new_args = new_jmp->args_;
for (size_t i = 0; i < new_args.size(); ++i) {
auto &old_arg = old_args[i];
auto &new_arg = new_args[i];
new_arg = MergeArgs(old_arg, new_arg);
}
}
// Combine two saved branch results into one value.
// Vector operands are flattened: (vec, vec) concatenates, (vec, x) appends x,
// (x, vec) appends x after vec's elements, and two scalars become a fresh
// two-element VectorRef.
BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) {
  MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString();
  const bool first_is_vec = utils::isa<VectorRef>(first);
  const bool second_is_vec = utils::isa<VectorRef>(second);
  if (first_is_vec) {
    auto merged = utils::cast<VectorRef>(first);
    if (second_is_vec) {
      auto tail = utils::cast<VectorRef>(second);
      (void)std::copy(tail.begin(), tail.end(), std::back_inserter(merged));
    } else {
      merged.push_back(second);
    }
    return merged;
  }
  if (second_is_vec) {
    auto merged = utils::cast<VectorRef>(second);
    // Matches the original: the scalar lands after the vector's elements.
    merged.push_back(first);
    return merged;
  }
  return VectorRef({first, second});
}
void FinalVM::InstRealSwitch(const VectorRef &args) {
const size_t args_size = 3;
if (args.size() != args_size) {
......@@ -472,11 +319,7 @@ void FinalVM::InstRealSwitch(const VectorRef &args) {
// Execute a switch instruction.
// The simulated (multi-graph-sink) variant was removed with the deprecated
// Ascend control-flow code; the rendered diff artifact here dispatched to
// InstRealSwitch twice on the non-sink path. Always take the real path
// exactly once.
void FinalVM::InstSwitch(const VectorRef &args) {
  MS_LOG(DEBUG) << "Start";
  InstRealSwitch(args);
  MS_LOG(DEBUG) << "End";
}
......@@ -580,14 +423,6 @@ void FinalVM::InstExternal(const VectorRef &args) {
VectorRef tuple;
RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]);
compile::RunFuncPtr fn = run_ref.func_;
if (backend_->simu_flag()) {
MS_LOG(DEBUG) << "Simu run";
if (args.size() == 1) {
MS_LOG(EXCEPTION) << "The number of args should be greater than 1, but got 1";
}
auto simu_run_ref = utils::cast<RunFunctionRef>(args[1]);
fn = simu_run_ref.func_;
}
for (size_t i = 2; i < args.size(); ++i) {
auto index = utils::cast<int>(args[i]);
tuple.push_back(Ref(index));
......
......@@ -96,7 +96,6 @@ class FinalVM {
public:
// Create a VM with the specified instructions and backend.
explicit FinalVM(const InstSet &insts, const BackendPtr &backend);
virtual ~FinalVM() = default;
BaseRef Eval(const VectorRef &args);
......@@ -104,10 +103,8 @@ class FinalVM {
void InstTailCall(const VectorRef &args);
void InstReturn(const VectorRef &args);
void InstPartial(const VectorRef &args);
void InstSimuPartial(const VectorRef &args);
void InstRealPartial(const VectorRef &args);
void InstSwitch(const VectorRef &args);
void InstSimuSwitch(const VectorRef &args);
void InstRealSwitch(const VectorRef &args);
void InstTuple(const VectorRef &args);
void InstPush(const VectorRef &args);
......@@ -129,23 +126,16 @@ class FinalVM {
void Popp();
void Pushsp();
void Popsp();
void PushStatus(bool is_switch_call);
bool PopStatus();
void DoJmp(const BaseRef &jmp);
void SyncData(const py::object &args);
void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
private:
InstSet insts_;
std::deque<BaseRef> insts_stack_;
std::stack<int> retp_;
std::stack<int> retsp_;
std::stack<bool> ret_status_;
int pc_;
int sp_;
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_;
BackendPtr backend_;
const InstFunctionMap inst_function_map = {
{Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册