提交 8003a89a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!766 bugfix(SA): Add support for nested loops.

Merge pull request !766 from gongchen/nest_loop
......@@ -27,14 +27,13 @@
#include <utility>
#include <initializer_list>
#ifdef DEBUG
#include "debug/draw.h"
#include "debug/anf_ir_dump.h"
#endif
#include "debug/trace.h"
#include "optimizer/opt.h"
#include "pipeline/resource.h"
#include "pipeline/action.h"
#include "debug/trace.h"
#include "utils/context/ms_context.h"
namespace mindspore {
namespace opt {
......@@ -133,7 +132,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
// Optimizer step counter;
int counter = 1;
int counter = -1;
bool changes = true;
while (changes) {
......@@ -170,13 +169,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
}
};
use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
#ifdef DEBUG
MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
auto fg_name = name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
func_graph->DumpFuncGraph(fg_name);
DumpIR(fg_name + ".ir", func_graph);
MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
#endif
if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) {
MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
auto fg_name =
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
func_graph->DumpFuncGraph(fg_name);
DumpIR(fg_name + ".ir", func_graph);
MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
}
}
};
use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter++)) run_runc) : run_runc();
......
......@@ -32,6 +32,7 @@
#include "pipeline/static_analysis/static_analysis.h"
#include "pipeline/static_analysis/program_specialize.h"
#include "pipeline/resource.h"
#include "utils/context/ms_context.h"
#include "pipeline/remove_value_node_dup.h"
#include "optimizer/optimizer.h"
#include "vm/transform.h"
......@@ -240,13 +241,23 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
}
bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) {
size_t counter = 0;
for (auto &pass : passes) {
WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res]() {
WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() {
MS_LOG(DEBUG) << "Pass " << pass.first << " start ...";
auto result = pass.second(res);
if (!result) {
MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
}
if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) {
auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
auto func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph->DumpFuncGraph(fg_name);
DumpIR(fg_name + ".ir", func_graph);
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
}
counter++;
MS_LOG(DEBUG) << "Pass " << pass.first << " end.";
};
}
......
......@@ -55,6 +55,7 @@ void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &
AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list) {
AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list);
normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list);
FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list);
MS_EXCEPTION_IF_NULL(parent_context_);
AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list);
......@@ -140,7 +141,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
<< ", broaded: " << mindspore::ToString(broaded_list);
return broaded_list;
}
return args_spec_list;
}
AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(func_graph_);
if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
return args_spec_list;
}
if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
if (parent_context_) {
MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
......@@ -160,6 +168,21 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
return joined_args_spec_list;
}
}
if (trace_.size() != 0) {
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back());
// Join the last eval arguments and current arguments to check if there are loop variant.
auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back());
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if (!(joined_args_spec_list == args_spec_list)) {
trace_.push_back(joined_args_spec_list);
func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
}
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
return joined_args_spec_list;
} else {
trace_.push_back(args_spec_list);
}
}
return args_spec_list;
}
......@@ -224,6 +247,7 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
return conf->GetEvaluatedValue();
});
args_spec_list = NormalizeArgs(args_spec_list);
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
trace::TraceGraphInferEnter(shared_from_base<Evaluator>(), out_conf);
InferEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
MS_EXCEPTION_IF_NULL(cache_);
......
......@@ -47,6 +47,10 @@ class Evaluator : public Base {
  // Canonicalizes the argument abstracts before inference/caching.
  // Default implementation is the identity; subclasses may broaden values.
  virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }

  // Broadens (drops constant values of) arguments when the evaluator is part
  // of an undetermined/recursive evaluation, to avoid wrong constant
  // propagation across loop iterations. Default implementation is the identity.
  virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) {
    return args_spec_list;
  }

  std::string ToString() const override { return identifier_; }

  // Node this evaluator is bound to, if still alive (bound_node_ is a weak_ptr).
  virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }
......@@ -181,12 +185,14 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
  FuncGraphPtr func_graph() { return func_graph_; }

  AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override;
  // Joins current arguments against trace_ to detect loop-variant values and
  // broadens them when this graph takes part in an undetermined recursion.
  AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override;
  std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); }

 private:
  FuncGraphPtr func_graph_;
  // Cache of specialized graphs keyed by the (normalized) argument abstracts.
  std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>
      func_graph_cache_;
  // History of argument lists from successive evaluations of this graph;
  // consulted by BroadenUndeterminedArgs to spot loop-variant arguments.
  std::vector<AbstractBasePtrList> trace_;
};
using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>;
......
......@@ -19,6 +19,7 @@
#include "pipeline/static_analysis/static_analysis.h"
#include <algorithm>
#include <set>
#include "pipeline/static_analysis/utils.h"
#include "pipeline/static_analysis/prim.h"
......@@ -239,7 +240,6 @@ AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeC
for (std::size_t i = 1; i < inputs.size(); i++) {
const AnfNodePtr &node = inputs[i];
args_conf_list.push_back(MakeConfig(node, context));
MS_LOG(DEBUG) << "Current CNode args_conf_list[" << i << "] node: " << node->DebugString();
}
std::vector<EvaluatorPtr> infs;
......@@ -469,6 +469,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {
AbstractBasePtrList out_specs;
if (!multi_poss_.count(evaluators[0])) {
multi_poss_[evaluators[0]] = evaluators[1];
multi_poss_[evaluators[1]] = evaluators[0];
}
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
......@@ -478,28 +482,81 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
for (auto eval : evaluators) {
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
if (fg_eval) {
auto undetermined_fgs = fg_eval->func_graph()->recursive_graphs();
auto fg = fg_eval->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto undetermined_fgs = fg->recursive_graphs();
if (undetermined_fgs) {
for (auto undetermined_fg : *undetermined_fgs) {
MS_LOG(DEBUG) << "Set graph undetermined: " << undetermined_fg->ToString();
// As the current evaluator has multiple possibles, all the func_graphs which
// are recursive with the current func_graph are undetermined in control flow.
undetermined_fg->set_flags(kFuncGraphFlagUndetermined, true);
}
auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent);
fg_parent->set_flags(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
}
}
auto current_inf = std::make_pair(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring.
auto it = std::find(eval_trace_.begin(), eval_trace_.end(), current_inf);
if (it == eval_trace_.end()) {
auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
if (it == eval_trace_.rend()) {
eval_trace_.push_back(current_inf);
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
MS_EXCEPTION_IF_NULL(eval);
auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(out_spec);
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString();
out_specs.push_back(out_spec);
MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString();
eval_trace_.pop_back();
if (eval_trace_.empty()) {
multi_poss_.clear();
}
} else if (it != eval_trace_.rbegin()) {
// Find latest entry function to handle nested recursion.
EvaluatorPtr latest_entry = eval;
auto latest_entry_iter = eval_trace_.rbegin();
for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first);
if (it_temp != evaluators.end()) {
latest_entry = *it_temp;
latest_entry_iter = r_it;
break;
}
latest_entry_iter = ++r_it;
}
if (latest_entry != eval) {
MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
continue;
}
bool has_undetermined = false;
// Check whether sub loop has untraced undetermined evaluator.
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
undetermined_evals.insert(*r_it);
}
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
has_undetermined = true;
break;
}
}
if (has_undetermined == false) {
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
continue;
}
// Try to travel the latest undetermined.
if (latest_entry != eval_trace_.rbegin()->first) {
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(out_spec);
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString();
return out_spec;
}
}
}
if (out_specs.size() == 0) {
......
......@@ -25,6 +25,7 @@
#include <unordered_map>
#include <vector>
#include <utility>
#include <map>
#ifdef DEBUG
#include <stack>
......@@ -206,6 +207,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
AnfNodeConfigMap anfnode_config_map_;
// Use a list to trace multiple evaluators.
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
const ConfigPtrList &args_conf_list);
......
......@@ -39,7 +39,6 @@ from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
def setup_module(module):
    # pytest module-level fixture: run every test in this module in PyNative mode.
    context.set_context(mode=context.PYNATIVE_MODE)
@ms_function
def while_upper_bound(upper):
rval = 2
......@@ -392,6 +391,58 @@ def test_grad_factorial():
res = C.grad(factorial)(3)
assert res == 11
@ms_function
def factorial2(n):
    """Recursive factorial fixture compiled in graph mode.

    NOTE(review): the `elif n == 1` branch is unreachable at runtime (n == 1
    already satisfies `n != 0`); it appears to be kept deliberately to give the
    static analyzer a multi-branch recursive graph — confirm before simplifying.
    """
    if n != 0:
        return n * factorial2(n-1)
    elif n == 1:
        return 1 * factorial2(n-1)
    else:
        return 1
def test_factorial2():
    # 3! computed by the recursive graph-mode fixture.
    assert factorial2(3) == 6
@ms_function
def foo(n):
    """Recursion-under-nested-conditionals fixture for static analysis.

    For any n >= 1 the recursion bottoms out at foo(0) == 1; the nested `if`
    shape is the point of the fixture, not the arithmetic.
    """
    if n <= 1:
        if n == 1:
            return foo(n-1)
        else:
            return 1
    else:
        return foo(n-1)
def test_foo():
    # Nested-conditional recursion terminates at 1 regardless of start value.
    assert foo(5) == 1
@ms_function
def double_nested_loop(x):
    """Nested-while fixture: a while loop inside a while loop, in graph mode.

    Each of the x outer iterations re-runs the inner loop, which adds
    1 + 2 + 3 = 6 to s; so the result is 6 * x.
    """
    i = 0
    s = 0
    while(i < x):
        j = 0
        i = i + 1
        while(j < 3):
            j = j + 1
            s = s + j
    return s
def test_nested_loop():
    # Three outer iterations, each contributing 1 + 2 + 3 from the inner loop.
    assert double_nested_loop(3) == 18
@ms_function
def double_nested_loop2(x):
    """Nested-for fixture: the for-based counterpart of double_nested_loop.

    Adds 0 + 1 + 2 = 3 per outer iteration, so the result is 3 * x.
    """
    s = 0
    for i in range(x):
        for j in range(3):
            s = s + j
    return s
def test_nested_loop2():
    """Exercise the for-based nested-loop fixture.

    Bug fix: this test previously called double_nested_loop (the while-based
    fixture), so double_nested_loop2 was never exercised at all.
    """
    res = double_nested_loop2(1)
    # One outer iteration of range(1); inner range(3) adds 0 + 1 + 2.
    assert res == 3
def _for(x):
""" _for """
ret = x * x
......
......@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
def setup_module(module):
    """pytest module-level fixture: configure the MindSpore context once.

    Bug fix: the span contained both the pre-change (save_graphs = True) and
    post-change (save_graphs = False) versions of the same call — a diff
    artifact that would invoke set_context with conflicting flags. Only the
    post-change line is kept.
    """
    # Graph dumping disabled (save_graphs = False) to keep test runs fast.
    context.set_context(mode = context.PYNATIVE_MODE, save_graphs = False, device_target = "Ascend")
    context.set_context(enable_task_sink = True, device_id = 0)
......@@ -86,7 +86,17 @@ def while_by_while(x, y, z):
x = x + 1
x = x + 1
return x
@ms_function
def while_in_while(x, y, z):
    """Nested-while fixture: inner while fully re-runs per outer iteration.

    NOTE(review): c4 is a module-level Tensor constant defined outside this
    chunk (presumably 0); with c1=2, c2=14 the outer loop runs 12 times, the
    inner loop accumulates 1 + ... + 14 = 105 each time, and the final
    `out + x` adds 14, giving 12 * 105 + 14 = 1274 as asserted by the test —
    confirm the constants against the top of the file.
    """
    out = c4
    while x < y:
        z = c4 + c4
        while z < y:
            z = z + 1
            out = out + z
        x = x + 1
    out = out + x
    return out
def test_simple_if():
output = simple_if(c1, c2, c3)
......@@ -117,3 +127,7 @@ def test_while_by_while():
expect = Tensor([28], mstype.int32)
assert output == expect
def test_while_in_while():
    # Nested while loops in graph mode must produce the precomputed constant.
    expected = Tensor([1274], mstype.int32)
    assert while_in_while(c1, c2, c3) == expected
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册