提交 066f20e7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!169 fix the method to calculate the children graph

Merge pull request !169 from xychow/fix-manager-children-issue
...@@ -985,40 +985,14 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { ...@@ -985,40 +985,14 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
} }
} }
// children include:
// A. func graphs which use variables in fg as free variables; (child_direct_)
// B. func graphs which call func func graph in A. (all_users_)
FuncGraphSetPtr ChildrenComputer::SeekChildren(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) {
if (path == nullptr || path->contains(fg)) {
return std::make_shared<FuncGraphSet>();
}
std::shared_ptr<FuncGraphSet> children = std::make_shared<FuncGraphSet>();
auto& deps = *child_direct_;
auto& users = *all_users_;
MS_LOG(DEBUG) << "" << fg->ToString() << " start func graph dep size:" << deps[fg].size();
for (auto& dep : deps[fg]) {
FuncGraphPtr child = dep.first;
children->add(child);
path->add(child);
MS_LOG(DEBUG) << "Child func graph:" << fg->ToString() << " child " << child->ToString();
for (auto& user : users[child]) {
auto user_func_graph = user.first;
MS_LOG(DEBUG) << "Func graph:" << fg->ToString() << " user " << user_func_graph->ToString();
children->add(user_func_graph);
path->add(user_func_graph);
}
children->update(SeekChildren(child, path));
}
(void)children->erase(fg);
MS_LOG(DEBUG) << "End in children: " << children->size();
return children;
}
void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
child_direct_ = &manager_->func_graph_child_direct(); auto used_fg_total = manager_->func_graphs_used_total(fg);
all_users_ = &manager_->func_graph_users(); for (auto& used_fg : used_fg_total) {
children_analysis_[fg].update(SeekChildren(fg)); if (manager_->parent(used_fg) == fg) {
children_analysis_[fg].add(used_fg);
}
}
} }
void ScopeComputer::RealRecompute(FuncGraphPtr fg) { void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
......
...@@ -398,11 +398,8 @@ class ParentComputer final : public DepComputer { ...@@ -398,11 +398,8 @@ class ParentComputer final : public DepComputer {
// graph's children graph except self // graph's children graph except self
class ChildrenComputer final : public DepComputer { class ChildrenComputer final : public DepComputer {
public: public:
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m), all_users_(nullptr), child_direct_(nullptr) {} explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {}
~ChildrenComputer() override { ~ChildrenComputer() override = default;
all_users_ = nullptr;
child_direct_ = nullptr;
}
FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; } FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; }
...@@ -414,13 +411,6 @@ class ChildrenComputer final : public DepComputer { ...@@ -414,13 +411,6 @@ class ChildrenComputer final : public DepComputer {
void ExtraReset() override { children_analysis_.clear(); } void ExtraReset() override { children_analysis_.clear(); }
void RealRecompute(FuncGraphPtr fg) override; void RealRecompute(FuncGraphPtr fg) override;
private:
FuncGraphSetPtr SeekChildren(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared<FuncGraphSet>());
// when SeekChildren calls itself recursively, it can access these variables by class member
// other than pass by formal parameters, it can save 2 parameters for SeekChildren().
FuncGraphToFuncGraphCounterMap* all_users_;
FuncGraphToFuncGraphCounterMap* child_direct_;
}; };
// graph's children graph include self // graph's children graph include self
......
...@@ -38,16 +38,6 @@ def setup_module(module): ...@@ -38,16 +38,6 @@ def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE) context.set_context(mode=context.PYNATIVE_MODE)
@ms_function
def refactor_fac(n):
""" grad_refactor_fac """
if n == 0:
return 1
return n * refactor_fac(n-1)
def test_refactor():
res = refactor_fac(3)
assert res == 6
@ms_function @ms_function
def while_upper_bound(upper): def while_upper_bound(upper):
rval = 2 rval = 2
...@@ -386,16 +376,19 @@ def test_grad_while(): ...@@ -386,16 +376,19 @@ def test_grad_while():
assert grad_while(5) == (60,) assert grad_while(5) == (60,)
@ms_function @ms_function
def fac(n): def factorial(n):
""" fac """ """ factorial """
if n == 0: if n == 0:
return 1 return 1
return n * fac(n-1) return n * factorial(n-1)
def test_factorial():
res = factorial(3)
assert res == 6
def test_fac(): def test_grad_factorial():
""" test_fac """ res = C.grad(factorial)(3)
res = fac(4) assert res == 11
assert res == 24
def _for(x): def _for(x):
""" _for """ """ _for """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册