From 9a6a3793467e37387e148af620394118357fedc7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 6 Jul 2021 15:21:11 +0800 Subject: [PATCH] feat(traced_module): add visit method GitOrigin-RevId: 251ecebf87c94fd5b60c27596a45149d479603e9 --- .../traced_module/traced_module.py | 46 ++++++++++++++++--- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index c6faa0305..9b8b8e34b 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -10,7 +10,7 @@ import collections import copy import functools from inspect import getmembers, isclass, ismethod -from typing import List, Type +from typing import Dict, List, Type from ... import module as M from ...core._imperative_rt.core2 import Tensor as RawTensor @@ -64,6 +64,14 @@ class InternalGraph: def insert(self, expr): self._exprs.append(expr) + @property + def inputs(self): + return self._inputs + + @property + def outputs(self): + return self._outputs + def add_input(self, i): self._inputs.append(i) @@ -271,6 +279,22 @@ class TracedModuleBuilder(NodeMixin): return wrapped +class _expr_list: + def __init__(self, module: "TracedModule"): + self.module = module + + def __iter__(self): + graph = self.module.m_node.graph + for expr in graph._exprs: + if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): + yield expr + assert isinstance(expr.inputs[0].expr, GetAttr) + (obj,) = expr.inputs[0].expr.interpret(self.module) + if isinstance(obj, TracedModule): + yield from obj.exprs + yield expr + + class TracedModule(Module): """ `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. @@ -291,14 +315,21 @@ class TracedModule(Module): return rst @property - def all_exprs(self): + def exprs(self): """ - Visit all ``Expr``s in the graph recursively. + Get all ``Expr`` s recursively. - :return: List[Expr] + :return: Iterator[Expr] """ + return _expr_list(self) - in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self] + def flatten(self): + """ + Get a new module, which eliminates ``GetAttr`` and has no hierarchy. + + :return: :class:`TracedModule` + """ + new_module = copy.deepcopy(self) def _flatten_submodule(module, call=None): if not isinstance(module, TracedModule): @@ -328,6 +359,7 @@ class TracedModule(Module): elif isinstance(expr, CallMethod): obj_node = expr.inputs[0] if isinstance(obj_node, ModuleNode): + assert isinstance(expr.inputs[0].expr, GetAttr) (obj,) = expr.inputs[0].expr.interpret(module) exprs.extend(_flatten_submodule(obj, expr)) else: @@ -337,7 +369,9 @@ class TracedModule(Module): return exprs - return in_nodes + _flatten_submodule(self) + new_module.m_node.graph._exprs = _flatten_submodule(new_module) + + return new_module def __getstate__(self): d = self.__dict__ -- GitLab